Skip to content

Commit 043c06f

Browse files
authored
Merge pull request rapidsai#951 from divyegala/revert-pr-915
Revert "Fix kmeans::predict argument order (rapidsai#915)"
1 parent e00fabe commit 043c06f

4 files changed

Lines changed: 13 additions & 77 deletions

File tree

cpp/include/cuvs/cluster/kmeans.hpp

Lines changed: 4 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -519,26 +519,10 @@ void predict(raft::resources const& handle,
519519
raft::device_matrix_view<const float, int> X,
520520
std::optional<raft::device_vector_view<const float, int>> sample_weight,
521521
raft::device_matrix_view<const float, int> centroids,
522-
bool normalize_weight,
523522
raft::device_vector_view<int, int> labels,
523+
bool normalize_weight,
524524
raft::host_scalar_view<float> inertia);
525525

526-
// This overload is retained for backward compatibility.
527-
[[deprecated(
528-
"The argument order of kmeans::predict has been corrected. Please use the new function "
529-
"instead.")]]
530-
inline void predict(raft::resources const& handle,
531-
const kmeans::params& params,
532-
raft::device_matrix_view<const float, int> X,
533-
std::optional<raft::device_vector_view<const float, int>> sample_weight,
534-
raft::device_matrix_view<const float, int> centroids,
535-
raft::device_vector_view<int, int> labels,
536-
bool normalize_weight,
537-
raft::host_scalar_view<float> inertia)
538-
{
539-
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
540-
}
541-
542526
/**
543527
* @brief Predict the closest cluster each sample in X belongs to.
544528
*
@@ -593,26 +577,10 @@ void predict(raft::resources const& handle,
593577
raft::device_matrix_view<const float, int> X,
594578
std::optional<raft::device_vector_view<const float, int>> sample_weight,
595579
raft::device_matrix_view<const float, int> centroids,
596-
bool normalize_weight,
597580
raft::device_vector_view<int64_t, int> labels,
581+
bool normalize_weight,
598582
raft::host_scalar_view<float> inertia);
599583

600-
// This overload is retained for backward compatibility.
601-
[[deprecated(
602-
"The argument order of kmeans::predict has been corrected. Please use the new function "
603-
"instead.")]]
604-
inline void predict(raft::resources const& handle,
605-
const kmeans::params& params,
606-
raft::device_matrix_view<const float, int> X,
607-
std::optional<raft::device_vector_view<const float, int>> sample_weight,
608-
raft::device_matrix_view<const float, int> centroids,
609-
raft::device_vector_view<int64_t, int> labels,
610-
bool normalize_weight,
611-
raft::host_scalar_view<float> inertia)
612-
{
613-
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
614-
}
615-
616584
/**
617585
* @brief Predict the closest cluster each sample in X belongs to.
618586
*
@@ -667,26 +635,10 @@ void predict(raft::resources const& handle,
667635
raft::device_matrix_view<const double, int> X,
668636
std::optional<raft::device_vector_view<const double, int>> sample_weight,
669637
raft::device_matrix_view<const double, int> centroids,
670-
bool normalize_weight,
671638
raft::device_vector_view<int, int> labels,
639+
bool normalize_weight,
672640
raft::host_scalar_view<double> inertia);
673641

674-
// This overload is retained for backward compatibility.
675-
[[deprecated(
676-
"The argument order of kmeans::predict has been corrected. Please use the new function "
677-
"instead.")]]
678-
inline void predict(raft::resources const& handle,
679-
const kmeans::params& params,
680-
raft::device_matrix_view<const double, int> X,
681-
std::optional<raft::device_vector_view<const double, int>> sample_weight,
682-
raft::device_matrix_view<const double, int> centroids,
683-
raft::device_vector_view<int, int> labels,
684-
bool normalize_weight,
685-
raft::host_scalar_view<double> inertia)
686-
{
687-
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
688-
}
689-
690642
/**
691643
* @brief Predict the closest cluster each sample in X belongs to.
692644
*
@@ -741,26 +693,10 @@ void predict(raft::resources const& handle,
741693
raft::device_matrix_view<const double, int> X,
742694
std::optional<raft::device_vector_view<const double, int>> sample_weight,
743695
raft::device_matrix_view<const double, int> centroids,
744-
bool normalize_weight,
745696
raft::device_vector_view<int64_t, int> labels,
697+
bool normalize_weight,
746698
raft::host_scalar_view<double> inertia);
747699

748-
// This overload is retained for backward compatibility.
749-
[[deprecated(
750-
"The argument order of kmeans::predict has been corrected. Please use the new function "
751-
"instead.")]]
752-
inline void predict(raft::resources const& handle,
753-
const kmeans::params& params,
754-
raft::device_matrix_view<const double, int> X,
755-
std::optional<raft::device_vector_view<const double, int>> sample_weight,
756-
raft::device_matrix_view<const double, int> centroids,
757-
raft::device_vector_view<int64_t, int> labels,
758-
bool normalize_weight,
759-
raft::host_scalar_view<double> inertia)
760-
{
761-
predict(handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
762-
}
763-
764700
/**
765701
* @brief Predict the closest cluster each sample in X belongs to.
766702
*

cpp/src/cluster/kmeans.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ void predict(raft::resources const& handle,
163163
raft::device_matrix_view<const DataT, IndexT> X,
164164
std::optional<raft::device_vector_view<const DataT, IndexT>> sample_weight,
165165
raft::device_matrix_view<const DataT, IndexT> centroids,
166-
bool normalize_weight,
167166
raft::device_vector_view<IndexT, IndexT> labels,
167+
bool normalize_weight,
168168
raft::host_scalar_view<DataT> inertia)
169169
{
170170
cuvs::cluster::kmeans::detail::kmeans_predict<DataT, IndexT>(

cpp/src/cluster/kmeans_predict_double.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,26 @@ void predict(raft::resources const& handle,
2424
raft::device_matrix_view<const double, int> X,
2525
std::optional<raft::device_vector_view<const double, int>> sample_weight,
2626
raft::device_matrix_view<const double, int> centroids,
27-
bool normalize_weight,
2827
raft::device_vector_view<int, int> labels,
28+
bool normalize_weight,
2929
raft::host_scalar_view<double> inertia)
3030

3131
{
3232
cuvs::cluster::kmeans::predict<double, int>(
33-
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
33+
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
3434
}
3535

3636
void predict(raft::resources const& handle,
3737
const kmeans::params& params,
3838
raft::device_matrix_view<const double, int> X,
3939
std::optional<raft::device_vector_view<const double, int>> sample_weight,
4040
raft::device_matrix_view<const double, int> centroids,
41-
bool normalize_weight,
4241
raft::device_vector_view<int64_t, int> labels,
42+
bool normalize_weight,
4343
raft::host_scalar_view<double> inertia)
4444

4545
{
4646
cuvs::cluster::kmeans::predict<double, int64_t>(
47-
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
47+
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
4848
}
4949
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_predict_float.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,25 @@ void predict(raft::resources const& handle,
2424
raft::device_matrix_view<const float, int> X,
2525
std::optional<raft::device_vector_view<const float, int>> sample_weight,
2626
raft::device_matrix_view<const float, int> centroids,
27-
bool normalize_weight,
2827
raft::device_vector_view<int, int> labels,
28+
bool normalize_weight,
2929
raft::host_scalar_view<float> inertia)
3030

3131
{
3232
cuvs::cluster::kmeans::predict<float, int>(
33-
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
33+
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
3434
}
3535
void predict(raft::resources const& handle,
3636
const kmeans::params& params,
3737
raft::device_matrix_view<const float, int> X,
3838
std::optional<raft::device_vector_view<const float, int>> sample_weight,
3939
raft::device_matrix_view<const float, int> centroids,
40-
bool normalize_weight,
4140
raft::device_vector_view<int64_t, int> labels,
41+
bool normalize_weight,
4242
raft::host_scalar_view<float> inertia)
4343

4444
{
4545
cuvs::cluster::kmeans::predict<float, int64_t>(
46-
handle, params, X, sample_weight, centroids, normalize_weight, labels, inertia);
46+
handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia);
4747
}
4848
} // namespace cuvs::cluster::kmeans

0 commit comments

Comments
 (0)