Skip to content

Commit a157105

Browse files
committed
ghjfgh
1 parent da71051 commit a157105

File tree

4 files changed

+122
-49
lines changed

4 files changed

+122
-49
lines changed

RATapi/wrappers.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Wrappers for the interface between RATapi and MATLAB custom files."""
22
import os
33
import pathlib
4+
import platform
5+
import shutil
46
from typing import Callable
57

68
import numpy as np
@@ -9,11 +11,36 @@
911
import RATapi.rat_core
1012

1113

12-
def find_matlab():
13-
pass
14+
MATLAB_PATH_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "matlab.txt")
1415

1516

17+
def set_matlab_path(path):
18+
if not path:
19+
return
20+
21+
path = pathlib.Path(path)
22+
if not path.is_dir():
23+
path = path.parent
24+
25+
if path.stem != 'bin':
26+
path = path / 'bin'
27+
28+
if platform.system() == "Windows":
29+
arch = "win64"
30+
elif platform.system() == "Darwin":
31+
arch = "maci64" if (path / "maci64").exists() else "maca64"
32+
else:
33+
arch = "glnxa64"
34+
35+
path = path / arch
36+
if not path.exists():
37+
raise FileNotFoundError(f"The expected MATLAB folders were in found at the path: {path}")
38+
39+
with open(MATLAB_PATH_FILE, "w") as path_file:
40+
path_file.write(path.as_posix())
1641

42+
return path.as_posix()
43+
1744

1845
def start_matlab():
1946
"""Start MATLAB asynchronously and returns a future to retrieve the engine later.
@@ -24,18 +51,23 @@ def start_matlab():
2451
A custom matlab engine wrapper.
2552
2653
"""
27-
28-
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "matlab.txt")
29-
# if pathlib(matlab_path).is_file()
30-
with open(path) as path_file:
31-
matlab_path = path_file.read()
54+
try:
55+
with open(MATLAB_PATH_FILE) as path_file:
56+
matlab_path = path_file.read()
57+
except FileNotFoundError:
58+
matlab_path = ""
3259

33-
os.environ["RAT_PATH"] = dir_path, "")
34-
os.environ["MATLAB_INSTALL_DIR"] += os.pathsep + "C:\\Program Files\\MATLAB\\R2023a\\bin\\win64"
35-
engine = RATapi.rat_core.MatlabEngine()
36-
engine.start()
60+
if not matlab_path:
61+
matlab_path = set_matlab_path(shutil.which("matlab"))
62+
if matlab_path is None:
63+
matlab_path = ""
64+
65+
if matlab_path:
66+
os.environ["PATH"] += os.pathsep + matlab_path
67+
engine = RATapi.rat_core.MatlabEngine()
68+
engine.start()
3769

38-
return engine
70+
return engine
3971

4072

4173

@@ -50,6 +82,9 @@ class MatlabWrapper:
5082
engine = start_matlab()
5183

5284
def __init__(self, filename) -> None:
85+
if self.engine is None:
86+
raise ValueError("MATLAB is not found. Please use `set_matlab_path` to set the location of your MATLAB installation") from None
87+
5388
path = pathlib.Path(filename)
5489
self.engine.cd(str(path.parent))
5590
self.engine.setFunction(path.stem)

cpp/matlab/matlabCallerImpl.hpp

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ class MatlabCaller
1818

1919
void setEngine(){
2020
if (!(matlabPtr = engOpen(""))) {
21-
throw("\nCan't start MATLAB engine\n");
21+
throw std::runtime_error("\nCan't start MATLAB engine\n");
2222
}
23+
engSetVisible(matlabPtr, 0);
2324
};
2425

2526
void startMatlab(){
@@ -31,16 +32,52 @@ class MatlabCaller
3132
dirChanged = true;
3233
};
3334

35+
void call(std::string functionName, std::vector<double>& xdata, std::vector<double>& params, std::vector<double>& output)
36+
{
37+
if (!this->matlabPtr)
38+
this->setEngine();
39+
40+
if (dirChanged){
41+
std::string cdCmd = "cd('" + (this->currentDirectory + "')");
42+
engEvalString(this->matlabPtr, cdCmd.c_str());
43+
}
44+
45+
dirChanged = false;
46+
mxArray *XDATA = mxCreateDoubleMatrix(1,xdata.size(),mxREAL);
47+
memcpy(mxGetPr(XDATA), &xdata[0], xdata.size()*sizeof(double));
48+
engPutVariable(this->matlabPtr, "xdata", XDATA);
49+
mxArray *PARAMS = mxCreateDoubleMatrix(1,params.size(),mxREAL);
50+
memcpy(mxGetPr(PARAMS), &params[0], params.size()*sizeof(double));
51+
engPutVariable(this->matlabPtr, "params", PARAMS);
52+
53+
std::string customCmd = "[output, subRough] = " + (functionName + "(params, bulkIn, bulkOut, contrast)");
54+
engPutVariable(this->matlabPtr, "myFunction", mxCreateString(customCmd.c_str()));
55+
engOutputBuffer(this->matlabPtr, NULL, 0);
56+
engEvalString(this->matlabPtr, "eval(myFunction)");
57+
58+
mxArray *matOutput = engGetVariable(this->matlabPtr, "output");
59+
if (matOutput == NULL)
60+
{
61+
throw std::runtime_error("ERROR: Results could not be extracted from MATLAB engine.");
62+
}
63+
64+
const mwSize* dims = mxGetDimensions(matOutput);
65+
double* s = (double *)mxGetData(matOutput);
66+
for (int i=0; i < dims[0] * dims[1]; i++)
67+
output.push_back(s[i]);
68+
};
69+
3470
void call(std::string functionName, std::vector<double>& params, std::vector<double>& bulkIn,
3571
std::vector<double>& bulkOut, int contrast, int domain, std::vector<double>& output, double* outputSize, double* rough)
3672
{
3773
if (!this->matlabPtr)
3874
this->setEngine();
75+
3976
if (dirChanged){
4077
std::string cdCmd = "cd('" + (this->currentDirectory + "')");
4178
engEvalString(this->matlabPtr, cdCmd.c_str());
4279
}
43-
//this->matlabPtr->feval(u"cd", factory.createCharArray(this->currentDirectory));
80+
4481
dirChanged = false;
4582
mxArray *PARAMS = mxCreateDoubleMatrix(1,params.size(),mxREAL);
4683
memcpy(mxGetPr(PARAMS), &params[0], params.size()*sizeof(double));
@@ -52,39 +89,34 @@ class MatlabCaller
5289
memcpy((void *)mxGetPr(BULKOUT), &bulkOut[0], bulkOut.size()*sizeof(double));
5390
engPutVariable(this->matlabPtr, "bulkOut", BULKOUT);
5491
mxArray *CONTRAST = mxCreateDoubleScalar(contrast);
55-
// memcpy((void *)mxGetPr(CONTRAST), &contrast, 1*sizeof(double));
5692
engPutVariable(this->matlabPtr, "contrast", CONTRAST);
57-
// if (domain > 0)
58-
// args.push_back(factory.createScalar<int>(domain));
59-
std::string customCmd = "[output, subRough] = " + (functionName + "(params, bulkIn, bulkOut, contrast)");
93+
std::string customCmd;
94+
if (domain > 0){
95+
mxArray *DOMAIN_NUM = mxCreateDoubleScalar(domain);
96+
engPutVariable(this->matlabPtr, "domain", DOMAIN_NUM);
97+
customCmd = "[output, subRough] = " + (functionName + "(params, bulkIn, bulkOut, contrast, domain)");
98+
}
99+
else {
100+
customCmd = "[output, subRough] = " + (functionName + "(params, bulkIn, bulkOut, contrast)");
101+
}
60102
engPutVariable(this->matlabPtr, "myFunction", mxCreateString(customCmd.c_str()));
61103
engOutputBuffer(this->matlabPtr, NULL, 0);
62-
//auto start = high_resolution_clock::now();
63-
// std::vector<matlab::data::Array> results = this->matlabPtr->feval(functionName, 2, args);
64104
engEvalString(this->matlabPtr, "eval(myFunction)");
65-
//auto stop = high_resolution_clock::now();
66-
//auto duration = duration_cast<microseconds>(stop - start);
67-
//std::cout << duration.count() << "Usec" << std::endl;
68105

69106
mxArray *matOutput = engGetVariable(this->matlabPtr, "output");
70-
if (matOutput == NULL)
71-
{
72-
throw("FAILED!");
73-
}
74-
mxArray *subRough = engGetVariable(this->matlabPtr, "subRough");
75-
if (subRough == NULL)
107+
mxArray *subRough = engGetVariable(this->matlabPtr, "subRough");
108+
if (matOutput == NULL || subRough == NULL)
76109
{
77-
throw("FAILED!");
110+
throw std::runtime_error("ERROR: Results could not be extracted from MATLAB engine.");
78111
}
112+
79113
*rough = (double)mxGetScalar(subRough);
80114
const mwSize* dims = mxGetDimensions(matOutput);
81115
outputSize[0] = (double) dims[0];
82116
outputSize[1] = (double) dims[1];
83-
// output.push_back((double) matOutput[i]);
84117
double* s = (double *)mxGetData(matOutput);
85118
for (int i=0; i < dims[0] * dims[1]; i++)
86119
output.push_back(s[i]);
87-
//std::memcpy(output, (double *)mxGetData(matOutput), mxGetNumberOfElements(matOutput)* mxGetElementSize(matOutput));
88120
};
89121

90122
static MatlabCaller* get_instance()

cpp/rat.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,17 @@ class MatlabEngine
6767

6868
py::list invoke(std::vector<double>& xdata, std::vector<double>& params)
6969
{
70-
// try{
71-
std::vector<double> output;
72-
73-
// auto func = library->get_function<void(std::vector<double>&, std::vector<double>&, std::vector<double>&)>(functionName);
74-
// func(xdata, params, output);
75-
76-
return py::cast(output);
70+
try{
71+
std::vector<double> output;
72+
73+
auto func = library->get_function<void(std::string, std::vector<double>&, std::vector<double>&, std::vector<double>&)>("callFunction");
74+
func(functionName, xdata, params, output);
75+
76+
return py::cast(output);
7777

78-
// }catch (const dylib::symbol_error &) {
79-
// throw std::runtime_error("failed to get dynamic library symbol for " + functionName);
80-
// }
78+
}catch (const dylib::symbol_error &) {
79+
throw std::runtime_error("failed to run MATLAB function: " + functionName);
80+
}
8181
};
8282

8383
py::tuple invoke(std::vector<double>& params, std::vector<double>& bulkIn, std::vector<double>& bulkOut, int contrast, int domain=DEFAULT_DOMAIN)
@@ -86,24 +86,24 @@ class MatlabEngine
8686
std::vector<double> tempOutput;
8787
double *outputSize = new double[2];
8888
double roughness = 0.0;
89+
8990
auto func = library->get_function<void(std::string, std::vector<double>&, std::vector<double>&, std::vector<double>&,
9091
int, int, std::vector<double>&, double*, double*)>("callFunction");
9192
func(functionName, params, bulkIn, bulkOut, contrast + 1, domain + 1, tempOutput, outputSize, &roughness);
92-
9393
py::list output;
9494
for (int32_T idx1{0}; idx1 < outputSize[0]; idx1++)
9595
{
9696
py::list rows;
9797
for (int32_T idx2{0}; idx2 < outputSize[1]; idx2++)
9898
{
99-
rows.append(tempOutput[(int32_T)outputSize[1] * idx1 + idx2]);
99+
rows.append(tempOutput[(int32_T)outputSize[0] * idx2 + idx1]);
100100
}
101101
output.append(rows);
102102
}
103103
return py::make_tuple(output, roughness);
104104

105105
}catch (const dylib::symbol_error &) {
106-
throw std::runtime_error("failed to get dynamic library symbol for " + functionName);
106+
throw std::runtime_error("failed to run MATLAB function: " + functionName);
107107
}
108108
};
109109
};

setup.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,12 @@ def run(self):
104104

105105
if self.inplace:
106106
obj_name = get_shared_object_name(libevent[0])
107-
src = f"{build_py.build_lib}/{PACKAGE_NAME}/{obj_name}"
108-
dest = f"{build_py.get_package_dir(PACKAGE_NAME)}/{obj_name}"
109-
build_py.copy_file(src, dest)
107+
build_py.copy_file(f"{build_py.build_lib}/{PACKAGE_NAME}/{obj_name}",
108+
f"{build_py.get_package_dir(PACKAGE_NAME)}/{obj_name}")
109+
110+
obj_name = get_shared_object_name(libmatlab[0])
111+
build_py.copy_file(f"{build_py.build_lib}/{PACKAGE_NAME}/{obj_name}",
112+
f"{build_py.get_package_dir(PACKAGE_NAME)}/{obj_name}")
110113

111114

112115
class BuildClib(build_clib):
@@ -159,7 +162,10 @@ def build_libraries(self, libraries):
159162
if self.matlab_install_dir:
160163
link_libraries.extend(["libeng", "libmx"])
161164
if platform.system() == "Windows":
162-
link_library_dirs.append(f"{self.matlab_install_dir}/extern/lib/win64/microsoft")
165+
link_library_dirs.append(f"{self.matlab_install_dir}/bin/win64")
166+
elif platform.system() == "Linux":
167+
link_library_dirs.append(f"{self.matlab_install_dir}/bin//microsoft")
168+
163169

164170
self.compiler.link_shared_object(
165171
objects,

0 commit comments

Comments
 (0)