IOLink 1.16.1
Loading...
Searching...
No Matches
TensorView.h
1#pragma once
2
3#include <iolink/DataType.h>
4#include <iolink/FlagSet.h>
5#include <iolink/IOLinkAPI.h>
6#include <iolink/Indexer.h>
7#include <iolink/RegionX.h>
8#include <iolink/VectorX.h>
9#include <iolink/view/View.h>
10
11namespace iolink
12{
13
15{
16 // Flags
17 READ = 0x1,
18 WRITE = 0x2,
19 MEMORY_ACCESS = 0x4,
20 RESHAPE = 0x8,
21
22 // Combinations
23 READ_WRITE = 0x3,
24};
25
26// Explicit-instantiation declaration of FlagSet<TensorCapability>: the library
27// provides the instantiation, so library users do not need to instantiate it.
28extern template class IOLINK_API_IMPORT FlagSet<TensorCapability>;
29
38
39// Define bitwise operators for TensorCapability so that capability flags can be combined.
// NOTE(review): the exact operator set comes from IOLINK_DEFINE_ENUM_BITWISE_OPERATORS — presumably |, & and ~; confirm.
40IOLINK_DEFINE_ENUM_BITWISE_OPERATORS(TensorCapability)
41
42
50class IOLINK_API TensorView : public View
51{
52public:
56 virtual ~TensorView() = default;
57
66 virtual TensorCapabilitySet capabilities() const = 0;
67
74 inline bool support(TensorCapabilitySet capabilities) const { return this->capabilities().has(capabilities); }
75
82 virtual const VectorXu64& shape() const = 0;
83
90 inline size_t dimensionCount() const { return this->shape().size(); }
91
98 virtual DataType dtype() const = 0;
99
103 std::string toString() const;
104
105 // ==== //
106 // READ //
107 // ==== //
108
118 virtual void read(const RegionXu64& region, void* dst) const;
119
132 template <typename T>
133 inline T at(const VectorXu64& index) const
134 {
135 T value;
136 const VectorXu64 size = VectorXu64::createUniform(this->dimensionCount(), 1);
137 const RegionXu64 region(index, size);
138 this->read(region, &value);
139 return value;
140 }
141
142 // ===== //
143 // WRITE //
144 // ===== //
145
155 virtual void write(const RegionXu64& region, const void* src);
156
169 template <typename T>
170 inline void setAt(const VectorXu64& index, T value)
171 {
172 const VectorXu64 size = VectorXu64::createUniform(this->dimensionCount(), 1);
173 const RegionXu64 region(index, size);
174 this->write(region, &value);
175 }
176
177 // ======= //
178 // RESHAPE //
179 // ======= //
180
181 virtual void reshape(const VectorXu64& shape, DataType dtype);
182
183 // ============= //
184 // MEMORY_ACCESS //
185 // ============= //
186
192 virtual const Indexer& indexer() const;
193
199 virtual size_t bufferSize() const;
200
207 virtual void* buffer();
208
214 virtual const void* buffer() const;
215};
216
217} // end namespace iolink