Skip to content
3 changes: 3 additions & 0 deletions include/Accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ namespace cytnx {
Accessor(const Accessor &rhs);
// copy assignment:
Accessor &operator=(const Accessor &rhs);

// check equality
bool operator==(const Accessor &rhs) const;
///@endcond

int type() const { return this->_type; }
Expand Down
14 changes: 12 additions & 2 deletions include/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,10 +1008,13 @@ namespace cytnx {
/**
@brief get elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements.
@param[out] removed the indices that were removed from the original shape of the Tensor are
pushed to the end of this vector. Usually, an empty vector should be passed.
@return [Tensor]
@see \link cytnx::Accessor Accessor\endlink for cordinate with Accessor in C++ API.
@note
1. the return will be a new Tensor instance, which not share memory with the current Tensor.
The return will be a new Tensor instance, which does not share memory with the current
Tensor.

## Equivalently:
One can also using more intruisive way to get the slice using [] operator.
Expand All @@ -1026,9 +1029,16 @@ namespace cytnx {
#### output>
\verbinclude example/Tensor/get.py.out
*/
Tensor get(const std::vector<cytnx::Accessor> &accessors,
std::vector<cytnx_int64> &removed) const {
Tensor out;
out._impl = this->_impl->get(accessors, removed);
return out;
}
Tensor get(const std::vector<cytnx::Accessor> &accessors) const {
Tensor out;
out._impl = this->_impl->get(accessors);
std::vector<cytnx_int64> removed;
out._impl = this->_impl->get(accessors, removed);
return out;
}

Expand Down
95 changes: 73 additions & 22 deletions include/UniTensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,12 @@ namespace cytnx {
if (this->is_diag()) {
cytnx_error_msg(
in.shape() != this->_block.shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in.clone();
} else {
cytnx_error_msg(
in.shape() != this->shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in.clone();
}
}
Expand All @@ -711,12 +711,12 @@ namespace cytnx {
if (this->is_diag()) {
cytnx_error_msg(
in.shape() != this->_block.shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in;
} else {
cytnx_error_msg(
in.shape() != this->shape(),
"[ERROR][DenseUniTensor] put_block, the input tensor shape does not match.%s", "\n");
"[ERROR][DenseUniTensor][put_block] the input tensor shape does not match.%s", "\n");
this->_block = in;
}
}
Expand All @@ -731,16 +731,9 @@ namespace cytnx {
true, "[ERROR][DenseUniTensor] try to put_block using qnum on a non-symmetry UniTensor%s",
"\n");
}
// this will only work on non-symm tensor (DenseUniTensor)
boost::intrusive_ptr<UniTensor_base> get(const std::vector<Accessor> &accessors) {
boost::intrusive_ptr<UniTensor_base> out(new DenseUniTensor());
out->Init_by_Tensor(this->_block.get(accessors), false, 0); // wrapping around.
return out;
}
// this will only work on non-symm tensor (DenseUniTensor)
void set(const std::vector<Accessor> &accessors, const Tensor &rhs) {
this->_block.set(accessors, rhs);
}
// these two methods only work on non-symm tensor (DenseUniTensor)
boost::intrusive_ptr<UniTensor_base> get(const std::vector<Accessor> &accessors);
void set(const std::vector<Accessor> &accessors, const Tensor &rhs);

void reshape_(const std::vector<cytnx_int64> &new_shape, const cytnx_uint64 &rowrank = 0);
boost::intrusive_ptr<UniTensor_base> reshape(const std::vector<cytnx_int64> &new_shape,
Expand Down Expand Up @@ -1700,7 +1693,7 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '+' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) "
"This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) "
"to do operation on blocks.");
}

Expand All @@ -1713,15 +1706,15 @@ namespace cytnx {
true,
"[ERROR] cannot perform elementwise arithmetic '-' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) "
"This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) "
"to do operation on blocks.");
}
void lSub_(const Scalar &lhs) {
cytnx_error_msg(
true,
"[ERROR] cannot perform elementwise arithmetic '-' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use get/set_block(s) "
"This operation would destroy the block structure. [Suggest] Avoid or use get/put_block(s) "
"to do operation on blocks.");
}

Expand All @@ -1733,7 +1726,7 @@ namespace cytnx {
"[ERROR] cannot perform elementwise arithmetic '/' between Scalar and BlockUniTensor.\n %s "
"\n",
"This operation would cause division by zero on non-block elements. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force,
const cytnx_double &tol);
Expand Down Expand Up @@ -2490,7 +2483,7 @@ namespace cytnx {
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}

void Mul_(const boost::intrusive_ptr<UniTensor_base> &rhs);
Expand All @@ -2503,15 +2496,15 @@ namespace cytnx {
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}
void lSub_(const Scalar &lhs) {
cytnx_error_msg(true,
"[ERROR] cannot perform elementwise arithmetic '-' between Scalar and "
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would destroy the block structure. [Suggest] Avoid or use "
"get/set_block(s) to do operation on blocks.");
"get/put_block(s) to do operation on blocks.");
}

void Div_(const boost::intrusive_ptr<UniTensor_base> &rhs);
Expand All @@ -2522,7 +2515,7 @@ namespace cytnx {
"BlockFermionicUniTensor.\n %s "
"\n",
"This operation would cause division by zero on non-block elements. "
"[Suggest] Avoid or use get/set_block(s) to do operation on blocks.");
"[Suggest] Avoid or use get/put_block(s) to do operation on blocks.");
}
void from_(const boost::intrusive_ptr<UniTensor_base> &rhs, const bool &force);

Expand Down Expand Up @@ -4280,11 +4273,69 @@ namespace cytnx {
in.permute_(new_order);
return *this;
}

/**
@brief get elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to get the elements.
@return [UniTensor]
@see Tensor::get, UniTensor::operator[]
@note
1. The return will be a new UniTensor instance, which does not share memory with the current
UniTensor.

2. Equivalently, one can also use the [] operator to access elements.

3. For diagonal UniTensors, the accessor list can have either one element (to address the
diagonal elements), or two elements (in this case, the output will be a non-diagonal
UniTensor).
*/
UniTensor get(const std::vector<Accessor> &accessors) const {
UniTensor out;
out._impl = this->_impl->get(accessors);
return out;
}

/**
@brief get elements using Accessor (C++ API) / slices (python API)
@see get()
*/
UniTensor operator[](const std::vector<cytnx::Accessor> &accessors) const {
UniTensor out;
out._impl = this->_impl->get(accessors);
return out;
}
UniTensor operator[](const std::initializer_list<cytnx::Accessor> &accessors) const {
std::vector<cytnx::Accessor> acc_in = accessors;
return this->get(acc_in);
}
UniTensor operator[](const std::vector<cytnx_int64> &accessors) const {
std::vector<cytnx::Accessor> acc_in;
for (cytnx_int64 i = 0; i < accessors.size(); i++) {
acc_in.push_back(cytnx::Accessor(accessors[i]));
}
return this->get(acc_in);
}
UniTensor operator[](const std::initializer_list<cytnx_int64> &accessors) const {
std::vector<cytnx_int64> acc_in = accessors;
return (*this)[acc_in];
}

/**
@brief set elements using Accessor (C++ API) / slices (python API)
@param[in] accessors the Accessor (C++ API) / slices (python API) to set the elements.
@param[in] rhs the tensor containing the values to set.
@return [UniTensor]
@see Tensor::set, UniTensor::operator[], UniTensor::get
@note
1. The return will be a new UniTensor instance, which does not share memory with the current
UniTensor.

2. Equivalently, one can also use the [] operator to access elements.

3. For diagonal UniTensors, the accessor list can have either one element (to address the
diagonal elements; rhs must be one-dimensional), or two elements (in this case, the
output will be a non-diagonal UniTensor; rhs must be two-dimensional).
*/
UniTensor &set(const std::vector<Accessor> &accessors, const Tensor &rhs) {
this->_impl->set(accessors, rhs);
return *this;
Expand Down
7 changes: 6 additions & 1 deletion include/backend/Tensor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,12 @@ namespace cytnx {
return this->_storage.at(RealRank);
}

boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors);
boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors,
std::vector<cytnx_int64> &removed);
boost::intrusive_ptr<Tensor_impl> get(const std::vector<cytnx::Accessor> &accessors) {
std::vector<cytnx_int64> removed;
return this->get(accessors, removed);
}
[[deprecated("Use Tensor_impl::get instead")]] boost::intrusive_ptr<Tensor_impl> get_deprecated(
const std::vector<cytnx::Accessor> &accessors);
void set(const std::vector<cytnx::Accessor> &accessors,
Expand Down
Loading
Loading