Program Listing for File subscript.hpp¶
↰ Return to documentation for file (src/subscript.hpp)
/* This file is part of brille.
Copyright © 2020 Greg Tucker <greg.tucker@stfc.ac.uk>
brille is free software: you can redistribute it and/or modify it under the
terms of the GNU Affero General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your option)
any later version.
brille is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with brille. If not, see <https://www.gnu.org/licenses/>. */
#ifndef BRILLE_SUBSCRIPT_HPP_
#define BRILLE_SUBSCRIPT_HPP_
#include <algorithm>
#include <array>
#include <cassert>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>
namespace brille {
/*! \brief Report whether an unsigned subscript is a valid index along a dimension.
 *
 * \param i subscript (unsigned, so it can never be negative)
 * \param n extent of the dimension
 * \return true when `i` lies inside [0, n)
 */
template<class I1, class I2>
std::enable_if_t<std::is_unsigned<I1>::value, bool>
is_fixed(I1 i, I2 n){
  const bool inside = n > i;
  return inside;
}
/*! \brief Report whether a signed subscript is a valid index along a dimension.
 *
 * Negative subscripts are never valid; otherwise the subscript is converted
 * to the extent's type before the upper-bound comparison.
 *
 * \param i subscript (signed)
 * \param n extent of the dimension
 * \return true when `i` lies inside [0, n)
 */
template<class I1, class I2>
std::enable_if_t<std::is_signed<I1>::value, bool>
is_fixed(I1 i, I2 n){
  // the conversion to I2 only happens once `i` is known not to be negative
  return !(i < 0) && static_cast<I2>(i) < n;
}
/*! \brief Iterate over every subscript of an N-dimensional shape, optionally
 *         holding some dimensions at fixed values.
 *
 * When constructed from a shape and input subscripts, dimension `i` is fixed
 * whenever `is_fixed(input[i], shape[i])` is true (the input is a valid index
 * along that dimension). Fixed dimensions keep their input value, and the
 * remaining dimensions are iterated with the last dimension varying fastest.
 *
 * A SubIt acts both as the range (`begin()`/`end()`) and as the iterator
 * (`operator++`, `operator*`, `operator==`).
 *
 * \tparam T integer type of the shape and subscripts
 */
template<class T> class SubIt {
public:
  typedef std::vector<T> holder;
  holder _shape;            //!< extent of each dimension
  holder _inpt;             //!< input subscripts, used to decide which dimensions are fixed
  holder _sub;              //!< current subscript
  std::vector<bool> _fixed; //!< true for dimensions held at their input value
  size_t _first;            //!< lowest-index (slowest-varying) non-fixed dimension
private:
  //! Set `_first` to the first non-fixed dimension.
  //! \throws std::runtime_error if every dimension is fixed
  void find_first() {
    auto fitr = std::find(_fixed.begin(), _fixed.end(), false);
    if (fitr == _fixed.end())
      throw std::runtime_error("The input subscripts have fixed all dimensions!");
    _first = static_cast<size_t>(fitr - _fixed.begin());
  }
public:
  //! A one-dimensional iterator with zero extent.
  explicit SubIt()
  : _shape({0}), _inpt({0}), _sub({0}), _fixed({false}), _first(0)
  {}
  //! Iterate over every subscript of `_sh`, with no dimension fixed.
  SubIt(const holder& _sh)
  : _shape(_sh), _inpt(_sh.size(), T(0)), _sub(_sh.size(), T(0)), _fixed(_sh.size(), false), _first(0)
  {}
  //! Iterate over `_sh`, fixing each dimension whose entry in `_in` is a valid index.
  //! \throws std::runtime_error if every dimension ends up fixed
  SubIt(const holder& _sh, const holder& _in)
  : _shape(_sh), _inpt(_in), _sub(_sh.size(), T(0)), _fixed(_sh.size(), false), _first(0)
  {
    assert(_shape.size() == _inpt.size());
    for (size_t i=0; i<_shape.size(); ++i){
      _fixed[i] = is_fixed(_inpt[i], _shape[i]);
      if (_fixed[i]) _sub[i] = _inpt[i];
    }
    this->find_first();
  }
  //! Construct from an explicit state. `_first` is recomputed from `_f`.
  //! \throws std::runtime_error if every dimension is fixed
  SubIt(const holder& _sh, const holder& _in, const holder& _s, const std::vector<bool>& _f)
  : _shape(_sh), _inpt(_in), _sub(_s), _fixed(_f), _first(0)
  {
    this->find_first();
  }
  SubIt(const SubIt<T>&) = default;
  //! Copy the state of the iterator pointed to by `o`, which must not be null.
  SubIt(const SubIt<T>* o)
  : _shape(o->_shape), _inpt(o->_inpt), _sub(o->_sub), _fixed(o->_fixed), _first(o->_first)
  {}
  SubIt& operator=(const SubIt<T>&) = default;
  const holder& shape() const {return _shape;}
  size_t ndim() const {return _shape.size();}
  /*! \brief Advance to the next subscript.
   *
   * The last non-fixed dimension is incremented. A dimension that reaches its
   * extent wraps to 0 and carries into the previous non-fixed dimension. The
   * first non-fixed dimension (`_first`) never wraps, so it eventually reaches
   * its extent, which is the state produced by `end()`.
   */
  SubIt& operator++(){
    const size_t n = this->ndim();
    for (size_t dim=n; dim-->0; ) if (!_fixed[dim]) {
      if (dim > _first && _sub[dim]+1 == _shape[dim]){
        _sub[dim] = 0u;
      } else {
        ++_sub[dim];
        break;
      }
    }
    return *this;
  }
  //! Two iterators are equal when they have the same rank and the same current subscript.
  bool operator==(const SubIt<T>& other) const {
    const size_t n = this->ndim();
    if (other.ndim() != n) return false;
    bool equal{true};
    for (size_t i=0; i<n; ++i) equal &= _sub[i] == other._sub[i];
    return equal;
  }
  bool operator!=(const SubIt<T>& other) const { return !(*this == other); }
  const holder& operator*() const {return _sub;}
  const holder* operator->() const {return &_sub;}
  holder& operator*() {return _sub;}
  holder* operator->() {return &_sub;}
  //! First subscript: fixed dimensions at their input value, all others at 0.
  SubIt<T> begin() const {
    const size_t n = this->ndim();
    holder sub(n, T(0));
    for (size_t i=0; i<n; ++i) if (_fixed[i]) sub[i] = _inpt[i];
    return SubIt<T>(_shape, _inpt, sub, _fixed);
  }
  //! One-past-the-end subscript: dimension `_first` at its extent, fixed dimensions unchanged.
  SubIt<T> end() const {
    const size_t n = this->ndim();
    holder val(n, T(0));
    for (size_t i=0; i<n; ++i) if (_fixed[i]) val[i] = _sub[i];
    if (_first < n) val[_first] = _shape[_first];
    return SubIt<T>(_shape, _inpt, val, _fixed);
  }
};
/*! \brief Fixed-size two-dimensional subscript iterator backed by std::array.
 *
 * When constructed from a shape and input subscripts, dimension `i` is fixed
 * whenever `is_fixed(input[i], shape[i])` is true. Fixed dimensions keep their
 * input value, and the other dimension(s) are iterated with dimension 1
 * varying fastest.
 *
 * The object acts both as the range (`begin()`/`end()`) and as the iterator.
 *
 * \tparam T integer type of the shape and subscripts
 */
template<class T> class SubIt2 {
public:
  typedef std::array<T,2> holder;
  holder _shape;              //!< extent of each dimension
  holder _inpt;               //!< input subscripts, used to decide which dimensions are fixed
  holder _sub;                //!< current subscript
  std::array<bool,2> _fixed;  //!< true for dimensions held at their input value
  size_t _first;              //!< lowest-index (slowest-varying) non-fixed dimension
private:
  //! Set `_first` to the first non-fixed dimension.
  //! \throws std::runtime_error if both dimensions are fixed
  void find_first() {
    auto fitr = std::find(_fixed.begin(), _fixed.end(), false);
    if (fitr == _fixed.end())
      throw std::runtime_error("The input subscripts have fixed all dimensions!");
    _first = static_cast<size_t>(fitr - _fixed.begin());
  }
public:
  //! A zero-extent 2D iterator (all members value-initialised to zero/false).
  explicit SubIt2()
  : _shape(), _inpt(), _sub(), _fixed(), _first(0)
  {}
  //! Iterate over every subscript of `_sh`, with no dimension fixed.
  SubIt2(const holder& _sh)
  : _shape(_sh), _inpt(), _sub(), _fixed(), _first(0)
  {}
  //! Iterate over `_sh`, fixing each dimension whose entry in `_in` is a valid index.
  //! \throws std::runtime_error if both dimensions end up fixed
  SubIt2(const holder& _sh, const holder& _in)
  : _shape(_sh), _inpt(_in), _sub(), _fixed(), _first(0)
  {
    for (size_t i=0; i<2; ++i){
      _fixed[i] = is_fixed(_inpt[i], _shape[i]);
      if (_fixed[i]) _sub[i] = _inpt[i];
    }
    this->find_first();
  }
  //! Construct from an explicit state. `_first` is recomputed from `_f`.
  //! \throws std::runtime_error if both dimensions are fixed
  SubIt2(const holder& _sh, const holder& _in, const holder& _s, const std::array<bool,2>& _f)
  : _shape(_sh), _inpt(_in), _sub(_s), _fixed(_f), _first(0)
  {
    this->find_first();
  }
  SubIt2(const SubIt2<T>&) = default;
  //! Copy the state of the iterator pointed to by `o`, which must not be null.
  SubIt2(const SubIt2<T>* o)
  : _shape(o->_shape), _inpt(o->_inpt), _sub(o->_sub), _fixed(o->_fixed), _first(o->_first)
  {}
  SubIt2& operator=(const SubIt2<T>&) = default;
  const holder& shape() const {return _shape;}
  size_t ndim() const {return 2;}
  /*! \brief Advance to the next subscript.
   *
   * Dimension 1 (if not fixed) is incremented first, wrapping to 0 and
   * carrying into dimension 0 when it is not `_first`. The `_first` dimension
   * never wraps, so it eventually reaches its extent (the `end()` state).
   */
  SubIt2& operator++(){
    for (size_t dim=2; dim-->0; ) if (!_fixed[dim]) {
      if (dim > _first && _sub[dim]+1 == _shape[dim]){
        _sub[dim] = 0u;
      } else {
        ++_sub[dim];
        break;
      }
    }
    return *this;
  }
  //! Two iterators are equal when their current subscripts match.
  bool operator==(const SubIt2<T>& other) const {
    return _sub[0] == other._sub[0] && _sub[1] == other._sub[1];
  }
  bool operator!=(const SubIt2<T>& other) const { return !(*this == other); }
  const holder& operator*() const {return _sub;}
  const holder* operator->() const {return &_sub;}
  holder& operator*() {return _sub;}
  holder* operator->() {return &_sub;}
  //! First subscript: fixed dimensions at their input value, others at 0.
  SubIt2<T> begin() const {
    holder sub({T(0), T(0)});
    if (_fixed[0]) sub[0] = _inpt[0];
    if (_fixed[1]) sub[1] = _inpt[1];
    return SubIt2<T>(_shape, _inpt, sub, _fixed);
  }
  //! One-past-the-end subscript: dimension `_first` at its extent, fixed dimensions unchanged.
  SubIt2<T> end() const {
    holder val({T(0), T(0)});
    if (_fixed[0]) val[0] = _sub[0];
    if (_fixed[1]) val[1] = _sub[1];
    if (_first < 2) val[_first] = _shape[_first];
    return SubIt2<T>(_shape, _inpt, val, _fixed);
  }
};
/*! \brief Iterate jointly over two equal-rank shapes broadcast to a common shape.
 *
 * Along every dimension the two extents must be equal or one of them must be
 * 1. An extent of 1 takes the other operand's extent, which may be 0. Any
 * zero output extent produces an empty iteration.
 *
 * Dereferencing yields the tuple (output subscript, subscript into the first
 * shape, subscript into the second shape). Along a dimension of extent 1 the
 * operand subscript stays at 0. Iteration advances the last dimension fastest.
 *
 * \tparam T integer type of the shapes and subscripts
 */
template<class T> class BroadcastIt{
public:
  typedef std::vector<T> holder;
private:
  holder _shape0;  //!< shape of the first operand
  holder _shape1;  //!< shape of the second operand
  holder _shapeO;  //!< broadcast (output) shape
  holder _sub0;    //!< current subscript into the first operand
  holder _sub1;    //!< current subscript into the second operand
  holder _subO;    //!< current output subscript
  //! Build the exception reported for two shapes that cannot be broadcast together.
  static std::runtime_error incompatible(const holder& a, const holder& b){
    std::string msg = "Can not broadcast { ";
    for (auto x: a) msg += std::to_string(x) + " ";
    msg += "} and { ";
    for (auto x: b) msg += std::to_string(x) + " ";
    msg += "} to a common shape";
    return std::runtime_error(msg);
  }
public:
  /*! \brief Broadcast two shapes with the same number of dimensions.
   * \throws std::runtime_error if the shapes differ in length or have
   *         incompatible extents
   */
  BroadcastIt(const holder& a, const holder& b)
  : _shape0(a), _shape1(b), _shapeO(a.size(),0), _sub0(a.size(),0), _sub1(a.size(),0), _subO(a.size(),0)
  {
    // a length mismatch would otherwise read past the end of the shorter shape
    if (_shape0.size() != _shape1.size()) throw incompatible(_shape0, _shape1);
    for (size_t i=0; i<_shape0.size(); ++i){
      if (_shape0[i]!=_shape1[i] && _shape0[i]!=1 && _shape1[i]!=1)
        throw incompatible(_shape0, _shape1);
      // an extent of 1 takes the other extent (including 0); otherwise the
      // extents are equal or the other one is 1, so _shape0[i] is the result
      _shapeO[i] = (_shape0[i] == 1) ? _shape1[i] : _shape0[i];
    }
  }
  //! Construct from an explicit state (shapes plus current subscripts).
  BroadcastIt(const holder& s0, const holder& s1, const holder & sO, const holder& i0, const holder& i1, const holder& iO)
  : _shape0(s0), _shape1(s1), _shapeO(sO), _sub0(i0), _sub1(i1), _subO(iO)
  {
  }
  holder shape() const {return _shapeO;}
  size_t ndim() const {return _shapeO.size();}
  /*! \brief Advance the output subscript, last dimension fastest.
   *
   * Dimensions other than 0 wrap to 0 at their extent and carry. Operand
   * subscripts follow the output subscript only along dimensions whose operand
   * extent exceeds 1.
   */
  BroadcastIt<T>& operator++(){
    const size_t n = this->ndim();
    for (size_t dim=n; dim-->0; ){
      if (dim > 0 && _subO[dim]+1 == _shapeO[dim]){
        _sub1[dim] = _sub0[dim] = _subO[dim] = 0u;
      } else {
        ++_subO[dim];
        if (_shape0[dim] > 1) _sub0[dim] = _subO[dim];
        if (_shape1[dim] > 1) _sub1[dim] = _subO[dim];
        break;
      }
    }
    return *this;
  }
  //! Iterators compare equal when they have the same rank and output subscript.
  bool operator==(const BroadcastIt<T>& other) const {
    const size_t n = this->ndim();
    if (other.ndim() != n) return false;
    bool equal{true};
    const holder& oO{other.outer()};
    for (size_t i=0; i<n; ++i) equal &= _subO[i] == oO[i];
    return equal;
  }
  bool operator!=(const BroadcastIt<T>& other) const {return !(*this==other);}
  //! (output subscript, first-operand subscript, second-operand subscript)
  std::tuple<holder,holder,holder> operator*() const {return triple_subscripts();}
  //! First position. Equals end() when any output extent is zero.
  BroadcastIt<T> begin() const {
    if (std::find(_shapeO.begin(), _shapeO.end(), T(0)) != _shapeO.end()) return this->end();
    const size_t n = this->ndim();
    holder s0(n,0), s1(n,0), sO(n,0);
    return BroadcastIt<T>(_shape0, _shape1, _shapeO, s0, s1, sO);
  }
  //! One-past-the-end position: output dimension 0 at its extent.
  //! For a zero-dimensional shape this equals begin().
  BroadcastIt<T> end() const {
    const size_t n = this->ndim();
    holder s0(n,0), s1(n,0), sO(n,0);
    if (n > 0) {
      sO[0] = _shapeO[0];
      if (_shape0[0] > 1) s0[0] = sO[0]; // not used in ==
      if (_shape1[0] > 1) s1[0] = sO[0]; // not used in ==
    }
    return BroadcastIt<T>(_shape0, _shape1, _shapeO, s0, s1, sO);
  }
private:
  std::tuple<holder,holder,holder> triple_subscripts() const {
    return std::make_tuple(_subO, _sub0, _sub1);
  }
protected:
  const holder& outer() const {return _subO;}
};
/*! \brief Iterate jointly over two 2D shapes broadcast to a common shape.
 *
 * Along each dimension the two extents must be equal or one of them must be
 * 1. An extent of 1 takes the other operand's extent, which may be 0. Any
 * zero output extent produces an empty iteration.
 *
 * Dereferencing yields the tuple (output subscript, subscript into the first
 * shape, subscript into the second shape). Along a dimension of extent 1 the
 * operand subscript stays at 0. Dimension 1 advances fastest.
 *
 * \tparam T integer type of the shapes and subscripts
 */
template<class T> class BroadcastIt2{
public:
  typedef std::array<T,2> holder;
private:
  holder _shape0;  //!< shape of the first operand
  holder _shape1;  //!< shape of the second operand
  holder _shapeO;  //!< broadcast (output) shape
  holder _sub0;    //!< current subscript into the first operand
  holder _sub1;    //!< current subscript into the second operand
  holder _subO;    //!< current output subscript
  //! Build the exception reported for two shapes that cannot be broadcast together.
  static std::runtime_error incompatible(const holder& a, const holder& b){
    std::string msg = "Can not broadcast { ";
    for (auto x: a) msg += std::to_string(x) + " ";
    msg += "} and { ";
    for (auto x: b) msg += std::to_string(x) + " ";
    msg += "} to a common shape";
    return std::runtime_error(msg);
  }
public:
  //! Broadcast two 2D shapes.
  //! \throws std::runtime_error if the extents are incompatible
  BroadcastIt2(const holder& a, const holder& b)
  : _shape0(a), _shape1(b), _shapeO({0,0}), _sub0({0,0}), _sub1({0,0}), _subO({0,0})
  {
    for (size_t i=0; i<2; ++i){
      if (_shape0[i]!=_shape1[i] && _shape0[i]!=1 && _shape1[i]!=1)
        throw incompatible(_shape0, _shape1);
      // an extent of 1 takes the other extent (including 0); otherwise the
      // extents are equal or the other one is 1, so _shape0[i] is the result
      _shapeO[i] = (_shape0[i] == 1) ? _shape1[i] : _shape0[i];
    }
  }
  //! Construct from an explicit state (shapes plus current subscripts).
  BroadcastIt2(const holder& s0, const holder& s1, const holder & sO, const holder& i0, const holder& i1, const holder& iO)
  : _shape0(s0), _shape1(s1), _shapeO(sO), _sub0(i0), _sub1(i1), _subO(iO)
  {
  }
  holder shape() const {return _shapeO;}
  size_t ndim() const {return 2;}
  /*! \brief Advance the output subscript, dimension 1 fastest.
   *
   * Dimension 1 wraps to 0 at its extent and carries into dimension 0.
   * Operand subscripts follow the output subscript only along dimensions whose
   * operand extent exceeds 1.
   */
  BroadcastIt2<T>& operator++(){
    for (size_t dim=2; dim-->0; ){
      if (dim > 0 && _subO[dim]+1 == _shapeO[dim]){
        _sub1[dim] = _sub0[dim] = _subO[dim] = 0u;
      } else {
        ++_subO[dim];
        if (_shape0[dim] > 1) _sub0[dim] = _subO[dim];
        if (_shape1[dim] > 1) _sub1[dim] = _subO[dim];
        break;
      }
    }
    return *this;
  }
  //! Iterators compare equal when their output subscripts match.
  bool operator==(const BroadcastIt2<T>& other) const {
    const holder& oO{other.outer()};
    return (_subO[0]==oO[0] && _subO[1]==oO[1]);
  }
  bool operator!=(const BroadcastIt2<T>& other) const {return !(*this==other);}
  //! (output subscript, first-operand subscript, second-operand subscript)
  std::tuple<holder,holder,holder> operator*() const {
    return std::make_tuple(_subO, _sub0, _sub1);
  }
  //! First position. Equals end() when either output extent is zero.
  BroadcastIt2<T> begin() const {
    if (_shapeO[0] == 0 || _shapeO[1] == 0) return this->end();
    holder s0({0,0}), s1({0,0}), sO({0,0});
    return BroadcastIt2<T>(_shape0, _shape1, _shapeO, s0, s1, sO);
  }
  //! One-past-the-end position: output dimension 0 at its extent.
  BroadcastIt2<T> end() const {
    holder s0({0,0}), s1({0,0}), sO({0,0});
    sO[0] = _shapeO[0];
    if (_shape0[0] > 1) s0[0] = sO[0]; // not used in ==
    if (_shape1[0] > 1) s1[0] = sO[0]; // not used in ==
    return BroadcastIt2<T>(_shape0, _shape1, _shapeO, s0, s1, sO);
  }
protected:
  const holder& outer() const {return _subO;}
};
template <class I>
std::vector<I> lin2sub(I l, const std::vector<I>& stride){
std::vector<I> sub;
size_t ndim = stride.size();
if (1 == ndim) sub.push_back(l);
else if (1 < ndim) {
sub.resize(ndim);
if (stride[ndim-1] > stride[0])
for (I i=ndim-1; i--; ){
sub[i] = l/stride[i];
l -= sub[i]*stride[i];
}
else
for (I i=0; i<ndim; ++i){
sub[i] = l/stride[i];
l -= sub[i]*stride[i];
}
}
return sub;
}
template <class I>
std::array<I,2> lin2sub(I l, const std::array<I,2>& stride){
std::array<I,2> sub;
if (stride[1] > stride[0]){
sub[1] = l/stride[1];
sub[0] = (l-sub[1]*stride[1])/stride[0];
} else {
sub[0] = l/stride[0];
sub[1] = (l-sub[0]*stride[0])/stride[1];
}
return sub;
}
// template <class I>
// I sub2lin(const std::vector<I>& sub, const std::vector<I>& stride){
// assert(sub.size() == stride.size());
// #if defined(__GNUC__) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2))
// // serial inner_product
// return std::inner_product(sub.begin(), sub.end(), stride.begin(), I(0));
// #else
// // parallelized inner_product
// return std::transform_reduce(sub.begin(), sub.end(), stride.begin(), I(0));
// #endif
// }
/*! \brief Linear index of a subscript: the sum of `sub[i]*str[i]`.
 *
 * \param sub subscripts
 * \param str strides, at least as long as `sub`
 * \return the accumulated linear index (0 for an empty `sub`)
 */
template <class I>
I sub2lin(const std::vector<I>& sub, const std::vector<I>& str){
  return std::inner_product(sub.begin(), sub.end(), str.begin(), I(0));
}
/*! \brief Linear index of a 2D subscript: `sub[0]*str[0] + sub[1]*str[1]`.
 *
 * \param sub subscript pair
 * \param str stride pair
 */
template <class I>
I sub2lin(const std::array<I,2>& sub, const std::array<I,2>& str){
  const auto term0 = sub[0]*str[0];
  const auto term1 = sub[1]*str[1];
  return term0 + term1;
}
/*! \brief Linear index of the 2D subscript (s0, s1): `s0*str[0] + s1*str[1]`.
 *
 * \param s0  subscript along dimension 0
 * \param s1  subscript along dimension 1
 * \param str stride pair
 */
template <class I>
I sub2lin(const I s0, const I s1, const std::array<I,2>& str){
  const auto term0 = s0*str[0];
  const auto term1 = s1*str[1];
  return term0 + term1;
}
template <class I>
I offset_sub2lin(const std::vector<I>& off, const std::vector<I>& sub, const std::vector<I>& str){
I lin{0};
for (size_t i=0; i<sub.size(); ++i) lin += (off[i]+sub[i])*str[i];
return lin;
}
/*! \brief Linear index of an offset 2D subscript:
 *         `(off[0]+sub[0])*str[0] + (off[1]+sub[1])*str[1]`.
 *
 * \param off offset pair
 * \param sub subscript pair relative to `off`
 * \param str stride pair
 */
template <class I>
I offset_sub2lin(const std::array<I,2>& off, const std::array<I,2>& sub, const std::array<I,2>& str){
  const auto along0 = (off[0]+sub[0])*str[0];
  const auto along1 = (off[1]+sub[1])*str[1];
  return along0 + along1;
}
/*! \brief Linear index of the offset 2D subscript (s0, s1):
 *         `(off[0]+s0)*str[0] + (off[1]+s1)*str[1]`.
 *
 * \param off offset pair
 * \param s0  subscript along dimension 0, relative to `off[0]`
 * \param s1  subscript along dimension 1, relative to `off[1]`
 * \param str stride pair
 */
template <class I>
I offset_sub2lin(const std::array<I,2>& off, const I s0, const I s1, const std::array<I,2>& str){
  const auto along0 = (off[0]+s0)*str[0];
  const auto along1 = (off[1]+s1)*str[1];
  return along0 + along1;
}
} // end namespace brille
#endif