Skip to content

Commit 2a5c78f

Browse files
authored
Transition default GPU data format from dpctl to dpnp (#202)
* Transition default GPU data format from dpctl to dpnp * deprecate dpctl tensors in sklbench
1 parent b8e821c commit 2a5c78f

File tree

4 files changed

+16
-2
lines changed

4 files changed

+16
-2
lines changed

configs/common/sklearn.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"estimator_params": { "n_jobs": "[REMOVE]" }
2020
},
2121
"data": {
22-
"format": "dpctl",
22+
"format": "dpnp",
2323
"order": "C",
2424
"distributed_split": "rank_based"
2525
},

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def dataframe_function(x):
363363
for i in range(n_batches):
364364
method_instance(x.iloc[i * batch_size : (i + 1) * batch_size])
365365

366-
if "ndarray" in str(type(data_args[0])):
366+
if "array" in str(type(data_args[0])):
367367
return ndarray_function
368368
elif "DataFrame" in str(type(data_args[0])):
369369
return dataframe_function

sklbench/datasets/transformer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# ===============================================================================
1616

1717
import os
18+
import warnings
1819

1920
import numpy as np
2021
import pandas as pd
@@ -47,6 +48,12 @@ def convert_data(data, dformat: str, order: str, dtype: str, device: str = None)
4748

4849
return dpnp.array(data, dtype=dtype, order=order, device=device)
4950
elif dformat == "dpctl":
51+
warnings.warn(
52+
"dpctl tensors are deprecated and support for them "
53+
"in scikit-learn_bench will be removed. "
54+
"Consider using dpnp arrays instead.",
55+
FutureWarning,
56+
)
5057
import dpctl.tensor
5158

5259
return dpctl.tensor.asarray(data, dtype=dtype, order=order, device=device)

sklbench/utils/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import json
2121
import re
2222
import subprocess as sp
23+
import warnings
2324
from pprint import pformat
2425
from shutil import get_terminal_size
2526
from typing import Any, Dict, List, Tuple, Union
@@ -214,6 +215,12 @@ def convert_to_numpy(a, dp_compat=False) -> np.ndarray:
214215

215216
return dpnp.asnumpy(a)
216217
elif "dpctl" in str(type(a)):
218+
warnings.warn(
219+
"dpctl tensors are deprecated and support for them "
220+
"in scikit-learn_bench will be removed. "
221+
"Consider using dpnp arrays instead.",
222+
FutureWarning,
223+
)
217224
import dpctl.tensor
218225

219226
return dpctl.tensor.to_numpy(a)

0 commit comments

Comments
 (0)