Skip to content
This repository was archived by the owner on Jun 8, 2023. It is now read-only.
Open
12 changes: 10 additions & 2 deletions csrc/litert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,10 @@ DIOPI_RT_API diopiError_t _diopiCreateContext(diopiContextHandle_t* ctx) {
return diopiSuccess;
}

DIOPI_RT_API diopiError_t _diopiDestroyContext(diopiContextHandle_t ctx) {
DIOPI_RT_API diopiError_t _diopiDestroyContext(diopiContextHandle_t* ctx) {
diopi_log("destroy a Context instance: %16p", ctx);
delete ctx;
delete *ctx;
*ctx = nullptr;
return diopiSuccess;
}

Expand Down Expand Up @@ -536,6 +537,13 @@ DIOPI_RT_API diopiError_t diopiFinalize() {
return diopiSuccess;
}

DIOPI_RT_API diopiError_t _diopiDeviceStreamSync(diopiContextHandle_t ctx) {
diopiStreamHandle_t stream;
diopiGetStream(ctx, &stream);
synchronize_stream_func(stream);
return diopiSuccess;
}

DIOPI_RT_API diopiError_t _diopiTensorCopyFromBuffer(diopiContextHandle_t ctx,
const void* src,
diopiTensorHandle_t tensor) {
Expand Down
27 changes: 16 additions & 11 deletions python/conformance/conformance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from .config import Config
from .utils import logger, FunctionNotImplementedError, DiopiException
from .utils import need_process_func, glob_vars, nhwc_op, dtype_op
from .diopi_runtime import Tensor, compute_nhwc_stride
from .diopi_runtime import Tensor, compute_nhwc_stride, Context
from .utils import save_precision, record, write_precision
from .utils import get_saved_pth_list, get_data_from_file
from .utils import cfg_file_name


def convert_input_tensors(function_paras: dict, test_tag: list, nhwc_list=[], dtype_list=[], filter_dtype_str_list=[]):
def convert_input_tensors(ctx, function_paras: dict, test_tag: list, nhwc_list=[], dtype_list=[], filter_dtype_str_list=[]):
tensor_info = []
for para in function_paras["kwargs"].keys():
tensor = function_paras['kwargs'][para]
Expand Down Expand Up @@ -46,13 +46,13 @@ def convert_input_tensors(function_paras: dict, test_tag: list, nhwc_list=[], dt
raise DiopiException(f"Skipped: {tensor.dtype} Tensor skipped for test")
if tensor is not None and str(tensor.dtype) not in test_tag:
test_tag.append(str(tensor.dtype))
function_paras['kwargs'][para] = Tensor.from_numpy(tensor)
function_paras['kwargs'][para] = Tensor.from_numpy(ctx, tensor)
tensor_info.append((para, str(tensor.dtype), str(tensor.shape)))

if para == "tensors":
tensors = function_paras['kwargs'][para]
for idx, ele in enumerate(tensors):
tensors[idx] = Tensor.from_numpy(ele)
tensors[idx] = Tensor.from_numpy(ctx, ele)
if ele is not None and str(ele.dtype) not in test_tag:
test_tag.append(str(ele.dtype))
function_paras['kwargs'][para] = tensors
Expand Down Expand Up @@ -385,15 +385,17 @@ def unlink_device_configs():
if data["cfg"].get("is_inplace", False):
func_call_list.append(f"{module}.{test_func_name}(**kwargs, inplace=True)")

ctx = Context()
for func_call in func_call_list:
if "inplace=True" in func_call:
if test_tag and test_tag[-1] == 'backward':
test_tag.pop()
test_tag.append("inplace")
try:
info = convert_input_tensors(function_paras, test_tag, nhwc_list, dtype_list, filter_dtype_str_list)
info = convert_input_tensors(ctx, function_paras, test_tag, nhwc_list, dtype_list, filter_dtype_str_list)
tensor_info = info if info else tensor_info
output = eval(func_call)
ctx.streamSync()
sum_to_compare = True if 'sorted' in kwargs and ~kwargs['sorted'] else False
passed = compare_with_gen_output(output, data['cfg'], output_reference, sum_to_compare) \
if need_output else True
Expand All @@ -410,13 +412,13 @@ def unlink_device_configs():
logger.error(f"output_reference:\n{output_reference}")
logger.error(f"output:\n{output}")
except FunctionNotImplementedError as e:
logger.error(f"NotImplemented: {e}")
logger.error(f"NotImplemented: {e} in {func_call}")
continue
except AttributeError as e:
logger.error(f"AttributeError: {e}")
logger.error(f"AttributeError: {e} in {func_call}")
continue
except Exception as e:
logger.error(f"{e}")
logger.error(f"{e} in {func_call}")
continue

write_precision(data["cfg"], cfg_func_name, passed)
Expand All @@ -443,6 +445,7 @@ def unlink_device_configs():

try:
grad_input = eval(f"F.{cfg_func_name}_backward(**kwargs, **backward_para)")
ctx.streamSync()
passed = compare_with_gen_output(grad_input, data['cfg'], backward_out_reference)
if passed:
logger.info(f"Run diopi_functions.{cfg_func_name}_backward succeed")
Expand All @@ -460,8 +463,10 @@ def unlink_device_configs():
logger.error(f"grad:\n{grad_input}")
write_precision(data["cfg"], cfg_func_name + '_bp', passed)
except FunctionNotImplementedError as e:
logger.error(f"NotImplemented: {e}")
logger.error(f"NotImplemented: {e} in {func_call}")
except AttributeError as e:
logger.error(f"AttributeError: {e}")
logger.error(f"AttributeError: {e} in {func_call}")
except Exception as e:
logger.error(f"Failed: {e}")
logger.error(f"Failed: {e} in {func_call}")
# do not forget to clear the ctx.
ctx.clear()
Loading