-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTensor.h
More file actions
58 lines (46 loc) · 1.31 KB
/
Tensor.h
File metadata and controls
58 lines (46 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#pragma once
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "optypes.h"
/// Element-type tag stored alongside a Tensor's data.
enum class DType {
    FLOAT,
    INT32,
    INT64
};

/// @brief Owning N-dimensional tensor: a shape, a DType tag, and a contiguous
///        std::vector<T> holding the elements.
///
/// Every non-default constructor establishes vec().size() == size().
/// NOTE(review): the default constructor leaves shape() empty (so size() == 1,
/// i.e. a scalar) but allocates no elements, so that invariant does not hold
/// for a default-constructed tensor — confirm no caller relies on this before
/// changing it.
/// NOTE(review): dtype() is only a tag; nothing checks that it agrees with T.
template <typename T>
class Tensor {
public:
    /// Empty tensor: no dimensions, no elements, dtype FLOAT.
    Tensor() : dtype_(DType::FLOAT) {}

    /// Tensor of the given shape whose size() elements are value-initialised.
    /// @throws std::overflow_error if the element count does not fit in size_t.
    Tensor(std::vector<size_t> shape, DType dtype)
        : shape_(std::move(shape)), dtype_(dtype), data_(size()) {}

    /// Tensor that takes ownership of an existing element buffer.
    /// @throws std::runtime_error if data.size() differs from the shape's element count.
    /// @throws std::overflow_error if the element count does not fit in size_t.
    Tensor(std::vector<size_t> shape, DType dtype, std::vector<T> data)
        : shape_(std::move(shape)), dtype_(dtype), data_(std::move(data)) {
        if (data_.size() != size()) {
            throw std::runtime_error("Tensor: data size does not match shape");
        }
    }

    /// Tensor that copies size() elements from a raw buffer.
    /// @param ptr Must point to at least size() readable elements. It may be
    ///            null only when the shape has zero elements.
    /// @throws std::runtime_error if ptr is null and the shape has elements.
    /// @throws std::overflow_error if the element count does not fit in size_t.
    Tensor(std::vector<size_t> shape, DType dtype, const T* ptr)
        : shape_(std::move(shape)), dtype_(dtype), data_(size()) {
        if (data_.empty()) {
            return;  // nothing to copy; a null ptr is acceptable here
        }
        if (ptr == nullptr) {
            throw std::runtime_error("Tensor: null data pointer for non-empty shape");
        }
        std::copy(ptr, ptr + data_.size(), data_.begin());
    }

    const std::vector<size_t>& shape() const { return shape_; }
    DType dtype() const { return dtype_; }
    T* data() { return data_.data(); }
    const T* data() const { return data_.data(); }

    /// Number of elements implied by shape(): the product of all dimensions.
    /// Returns 1 for an empty shape and 0 if any dimension is 0.
    /// @throws std::overflow_error if the product does not fit in size_t.
    size_t size() const {
        // A zero dimension makes the product 0 whatever the other dimensions
        // are. Checking it first stops huge-but-irrelevant dimensions from
        // reporting a false overflow.
        if (std::find(shape_.begin(), shape_.end(), size_t{0}) != shape_.end()) {
            return 0;
        }
        size_t total = 1;
        for (const size_t dim : shape_) {
            // Without this check, wraparound would yield a buffer smaller than shape() implies.
            if (total > std::numeric_limits<size_t>::max() / dim) {
                throw std::overflow_error("Tensor: element count overflows size_t");
            }
            total *= dim;
        }
        return total;
    }

    std::vector<T>& vec() { return data_; }
    const std::vector<T>& vec() const { return data_; }

private:
    // Declared in the same order as every constructor's initializer list.
    // shape_ must precede data_ because data_ is sized from size().
    std::vector<size_t> shape_;
    DType dtype_;
    std::vector<T> data_;
};