// NOTE(review): Damaged copy of the Sphinx/exhale "Program Listing" page for
// src/subscript.hpp. The original newlines were collapsed into six long lines, and every
// span starting with '<' followed by a letter appears to have been deleted up to the next
// '>'. Comparisons such as i<0 or i<2 survived, but the #include lines lost their header
// names, `template` lost its parameter lists, `std::vector holder` lost its element type,
// and the :ref: role lost its target. Because the newlines are gone, each `//` comment now
// runs to the end of its collapsed line and hides the code after it. Neither the RST nor
// the C++ below is valid as shown, and several stripped spans cannot be recovered from this
// text. Regenerate the page from the header rather than hand-editing it. The review notes
// interleaved below describe only what the surviving text shows.
.. _program_listing_file_src_subscript.hpp: Program Listing for File subscript.hpp ====================================== |exhale_lsh| :ref:`Return to documentation for file ` (``src/subscript.hpp``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp /* This file is part of brille. Copyright © 2020 Greg Tucker brille is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. brille is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with brille. If not, see <https://www.gnu.org/licenses/>. */ #ifndef BRILLE_SUBSCRIPT_HPP_ #define BRILLE_SUBSCRIPT_HPP_ #include #include #include #include #include #include #include namespace brille { template std::enable_if_t::value, bool> is_fixed(I1 i, I2 n){ return i std::enable_if_t::value, bool> is_fixed(I1 i, I2 n){ if (i<0) return false; return static_cast(i) class SubIt { public: typedef std::vector holder; holder _shape; holder _inpt; holder _sub; std::vector _fixed; size_t _first; private: // void find_first() { // // find the first non-fixed index (or the length of the _fixed vector) // auto fitr = std::find(_fixed.begin(), _fixed.end(), false); // if (fitr == _fixed.end()) // throw std::runtime_error("The input subscripts have fixed all dimensions!"); // _first = std::distance(_fixed.begin(), fitr); // } void find_first() { // find the first non-fixed index or the length of the _fixed vector _first = _fixed.size(); for (size_t i=_first; i-- > 0;) if (!_fixed[i]) _first=i; if (_first == _fixed.size()) throw std::runtime_error("The input subscripts have fixed all dimensions!"); } public: 
// NOTE(review): The line above holds the RST header, the AGPL notice, the include guard,
// the two is_fixed(i, n) overloads and the opening of class SubIt.
// From is_fixed, only the signed overload's `if (i<0) return false;` survives. The
// comparisons against n were stripped (presumably i<n -- confirm against the header).
// SubIt iterates multi-dimensional subscripts held in a std::vector whose element type
// was stripped. Its members are:
//   _shape  extent of each dimension
//   _inpt   the input subscripts
//   _sub    the current subscript
//   _fixed  per-dimension flags marking dimensions held at their input value
//   _first  the lowest dimension that is not fixed
// find_first() scans _fixed from the back, leaving _first at the smallest index whose flag
// is false. If every flag is true it throws std::runtime_error.
explicit SubIt() : _shape({0}), _inpt({0}), _sub({0}), _fixed({false}), _first(0) {} SubIt(const holder& _sh) : _shape(_sh), _first(0) { size_t n = _shape.size(); _inpt = holder(n, T(0)); _sub = holder(n, T(0)); _fixed = std::vector(n, false); } SubIt(const holder& _sh, const holder& _in) : _shape(_sh), _inpt(_in) { assert(_shape.size() == _inpt.size()); size_t n = _shape.size(); _sub = holder(n, T(0)); _fixed = std::vector(n, false); for (size_t i=0; ifind_first(); } SubIt(const holder& _sh, const holder& _in, const holder& _s, const std::vector& _f) : _shape(_sh), _inpt(_in), _sub(_s), _fixed(_f) { this->find_first(); } SubIt(const SubIt& o) : _shape(o._shape), _inpt(o._inpt), _sub(o._sub), _fixed(o._fixed), _first(o._first) {} SubIt(const SubIt* o) : _shape(o->_shape), _inpt(o->_inpt), _sub(o->_sub), _fixed(o->_fixed), _first(o->first) {} SubIt& operator=(const SubIt& o){ _shape = o._shape; _inpt = o._inpt; _sub = o._sub; _fixed = o._fixed; _first = o._first; return *this; } const holder& shape() const {return _shape;} size_t ndim() const {return _shape.size();} SubIt& operator++(){ size_t n = this->ndim(); for (size_t dim=n; dim-->0; ) if (!_fixed[dim]) { if (dim > _first && _sub[dim]+1 == _shape[dim]){ _sub[dim] = 0u; } else { ++_sub[dim]; break; } } return *this; } bool operator==(const SubIt& other) const { size_t n = this->ndim(); if (other.ndim() != n) return false; bool equal{true}; for (size_t i=0; i& other) const { return !(*this == other); } const holder& operator*() const {return _sub;} const holder* operator->() const {return &_sub;} holder& operator*() {return _sub;} holder* operator->() {return &_sub;} SubIt begin() const { // return SubIt(_shape, _inpt, _sub, _fixed); size_t n = this->ndim(); holder sub(n,T(0)); for (size_t i=0; i(_shape, _inpt, sub, _fixed); } SubIt end() const { size_t n = this->ndim(); holder val(n, T(0)); for (size_t i=0; i(_shape, _inpt, val, _fixed); } }; template class SubIt2 { public: typedef std::array holder; holder 
// NOTE(review): SubIt, continued. The (shape, input) constructor's loop that fills _fixed
// and _sub was stripped, as were most of operator== and the begin()/end() bodies.
// operator++ walks the dimensions from last to first and skips fixed ones:
//   - A dimension above _first that is at its last value is reset to 0, and the walk
//     continues to the previous dimension (the carry).
//   - Otherwise the dimension is incremented and the walk stops.
// Dimension _first is never reset, so _sub[_first] is simply incremented past its last
// valid value.
// BUG(review): SubIt(const SubIt* o) initialises _first from o->first. No member named
// `first` is declared (the member is _first), so this constructor cannot compile if it is
// instantiated. It should read o->_first.
_shape; holder _inpt; holder _sub; std::array _fixed; size_t _first; private: void find_first() { // find the first non-fixed index or the length of the _fixed vector _first = _fixed.size(); for (size_t i=_first; i-- > 0;) if (!_fixed[i]) _first=i; if (_first == _fixed.size()) throw std::runtime_error("The input subscripts have fixed all dimensions!"); } public: explicit SubIt2() : _shape({0,0}), _inpt({0,0}), _sub({0,0}), _fixed({false,false}), _first(0) {} SubIt2(const holder& _sh) : _shape(_sh), _first(0) { _inpt = holder({T(0), T(0)}); _sub = holder({T(0), T(0)}); _fixed = std::array({false, false}); } SubIt2(const holder& _sh, const holder& _in) : _shape(_sh), _inpt(_in) { _sub = holder({T(0), T(0)}); _fixed = std::array({false, false}); for (size_t i=0; i<2; ++i){ _fixed[i] = is_fixed(_inpt[i], _shape[i]); _sub[i] = _fixed[i] ? _inpt[i] : T(0); } this->find_first(); } SubIt2(const holder& _sh, const holder& _in, const holder& _s, const std::array& _f) : _shape(_sh), _inpt(_in), _sub(_s), _fixed(_f) { this->find_first(); } SubIt2(const SubIt2& o) : _shape(o._shape), _inpt(o._inpt), _sub(o._sub), _fixed(o._fixed), _first(o._first) {} SubIt2(const SubIt2* o) : _shape(o->_shape), _inpt(o->_inpt), _sub(o->_sub), _fixed(o->_fixed), _first(o->first) {} SubIt2& operator=(const SubIt2& o){ _shape = o._shape; _inpt = o._inpt; _sub = o._sub; _fixed = o._fixed; _first = o._first; return *this; } const holder& shape() const {return _shape;} size_t ndim() const {return 2;} SubIt2& operator++(){ size_t n = 2; for (size_t dim=n; dim-->0; ) if (!_fixed[dim]) { if (dim > _first && _sub[dim]+1 == _shape[dim]){ _sub[dim] = 0u; } else { ++_sub[dim]; break; } } return *this; } bool operator==(const SubIt2& other) const {; return _sub[0] == other._sub[0] && _sub[1] == other._sub[1]; } bool operator!=(const SubIt2& other) const { return !(*this == other); } const holder& operator*() const {return _sub;} const holder* operator->() const {return &_sub;} holder& operator*() {return 
// NOTE(review): The line above is the body of SubIt2. It is a two-dimensional counterpart
// of SubIt that stores its state in std::array (template arguments stripped) and uses the
// same find_first() and operator++ logic.
// Its (shape, input) constructor is intact. It sets _fixed[i] = is_fixed(_inpt[i],
// _shape[i]), then starts each fixed dimension at its input value and the others at 0.
// operator== compares only _sub[0] and _sub[1]. The leading `;` in its body is an empty
// statement.
// BUG(review): SubIt2(const SubIt2* o) has the same o->first typo as SubIt. It should
// read o->_first.
_sub;} holder* operator->() {return &_sub;} // std::tuple operator*() const {return std::make_tuple(_sub[0], _sub[1])}; SubIt2 begin() const { holder sub({T(0), T(0)}); if (_fixed[0]) sub[0] = _inpt[0]; if (_fixed[1]) sub[1] = _inpt[1]; return SubIt2(_shape, _inpt, sub, _fixed); } SubIt2 end() const { holder val({T(0), T(0)}); if (_fixed[0]) val[0] = _sub[0]; if (_fixed[1]) val[1] = _sub[1]; if (_first < 2) val[_first] = _shape[_first]; return SubIt2(_shape, _inpt, val, _fixed); } }; template class BroadcastIt{ public: typedef std::vector holder; private: holder _shape0; holder _shape1; holder _shapeO; holder _sub0; holder _sub1; holder _subO; public: // BroadcastIt(const holder& a, const holder& b) : _shape0(a), _shape1(b), _shapeO(a.size(),0), _sub0(a.size(),0), _sub1(a.size(),0), _subO(a.size(),0) { assert(_shape0.size() == _shape1.size()); size_t nd = _shape0.size(); for (size_t i=0; i& itr() const {return _itr;} BroadcastIt& operator++(){ size_t n = this->ndim(); for (size_t dim=n; dim-->0; ){ if (dim > 0 && _subO[dim]+1 == _shapeO[dim]){ _sub1[dim] = _sub0[dim] = _subO[dim] = 0u; } else { ++_subO[dim]; if (_shape0[dim] > 1) _sub0[dim] = _subO[dim]; if (_shape1[dim] > 1) _sub1[dim] = _subO[dim]; break; } } return *this; } bool operator==(const BroadcastIt& other) const { size_t n = this->ndim(); if (other.ndim() != n) return false; bool equal{true}; const holder& oO{other.outer()}; for (size_t i=0; i& other) const {return !(*this==other);} // const subs_t& operator*() const {return subs;} // const subs_t* operator->() const {return &subs;} // subs_t& operator*() {return subs;} // subs_t* operator->() {return &subs;} std::tuple operator*() const {return triple_subscripts();} BroadcastIt begin() const { size_t n = this->ndim(); holder s0(n,0), s1(n,0), sO(n,0); return BroadcastIt(_shape0, _shape1, _shapeO, s0, s1, sO); } BroadcastIt end() const { size_t n = this->ndim(); holder s0(n,0), s1(n,0), sO(n,0); sO[0] = _shapeO[0]; if (_shape0[0] > 1) s0[0] = sO[0]; // 
// NOTE(review): The line above ends SubIt2 and holds most of BroadcastIt.
// SubIt2::begin() starts fixed dimensions at their input values and the rest at 0.
// SubIt2::end() keeps the fixed values from _sub and, when _first < 2, sets
// val[_first] = _shape[_first]. That is the state operator++ produces when it steps past
// the last valid subscript.
// BroadcastIt iterates an output subscript _subO over _shapeO while tracking matching
// subscripts _sub0/_sub1 into two input shapes, _shape0/_shape1.
// Everything between the first constructor's `for (size_t i=0; i` and
// `& itr() const {return _itr;}` was stripped; no _itr member survives. The stripped span
// includes:
//   - the rest of that constructor (presumably where _shapeO is computed);
//   - the six-argument constructor that begin()/end() call;
//   - ndim(), which operator++ calls but which is not visible here.
// operator++ carries like SubIt, except that dimension 0 never resets. It copies
// _subO[dim] into _sub0/_sub1 only where that input extent is greater than 1, so size-1
// (broadcast) dimensions stay at 0.
not used in == if (_shape1[0] > 1) s1[0] = sO[0]; // not used in == return BroadcastIt(_shape0, _shape1, _shapeO, s0, s1, sO); } private: // std::tuple triple_subscripts() const { // holder o{*_itr}; // holder a(o), b(o); // for (size_t i=0; i0) { // if (1==_shape0[i]) a[i] = 0; // if (1==_shape1[i]) b[i] = 0; // } // return std::make_tuple(o,a,b); // } std::tuple triple_subscripts() const { return std::make_tuple(_subO, _sub0, _sub1); } protected: const holder& outer() const {return _subO;} }; template class BroadcastIt2{ public: typedef std::array holder; private: holder _shape0; holder _shape1; holder _shapeO; holder _sub0; holder _sub1; holder _subO; public: // BroadcastIt2(const holder& a, const holder& b) : _shape0(a), _shape1(b), _shapeO({0,0}), _sub0({0,0}), _sub1({0,0}), _subO({0,0}) { size_t nd = 2; for (size_t i=0; i& itr() const {return _itr;} BroadcastIt2& operator++(){ size_t n = 2; for (size_t dim=n; dim-->0; ){ if (dim > 0 && _subO[dim]+1 == _shapeO[dim]){ _sub1[dim] = _sub0[dim] = _subO[dim] = 0u; } else { ++_subO[dim]; if (_shape0[dim] > 1) _sub0[dim] = _subO[dim]; if (_shape1[dim] > 1) _sub1[dim] = _subO[dim]; break; } } return *this; } bool operator==(const BroadcastIt2& other) const { const holder& oO{other.outer()}; return (_subO[0]==oO[0] && _subO[1]==oO[1]); } bool operator!=(const BroadcastIt2& other) const {return !(*this==other);} std::tuple operator*() const { return std::make_tuple(_subO, _sub0, _sub1); } BroadcastIt2 begin() const { holder s0({0,0}), s1({0,0}), sO({0,0}); return BroadcastIt2(_shape0, _shape1, _shapeO, s0, s1, sO); } BroadcastIt2 end() const { holder s0({0,0}), s1({0,0}), sO({0,0}); sO[0] = _shapeO[0]; if (_shape0[0] > 1) s0[0] = sO[0]; // not used in == if (_shape1[0] > 1) s1[0] = sO[0]; // not used in == return BroadcastIt2(_shape0, _shape1, _shapeO, s0, s1, sO); } protected: const holder& outer() const {return _subO;} }; template std::vector lin2sub(I l, const std::vector& stride){ std::vector sub; size_t ndim = 
// NOTE(review): The line above finishes BroadcastIt, contains BroadcastIt2, and opens the
// std::vector overload of lin2sub.
// BroadcastIt:
//   - operator* returns the tuple (_subO, _sub0, _sub1) via triple_subscripts().
//   - end() sets only the outer subscript, sO[0] = _shapeO[0].
//   - Its "not used in ==" comments indicate that equality compares only the outer
//     subscript. The comparison loop itself was stripped.
// BroadcastIt2:
//   - It is the std::array version of BroadcastIt.
//   - operator== compares _subO[0] and _subO[1] with other.outer().
//   - operator* builds the same tuple directly.
//   - Its constructor loop was stripped.
// BUG(review): lin2sub(std::vector), branch taken when stride[ndim-1] > stride[0].
//   - `for (I i=ndim-1; i--; )` tests i before decrementing it, so the first pass uses
//     i = ndim-2.
//   - sub[ndim-1] is therefore never assigned and keeps the 0 from resize(ndim).
//   - For ndim == 2 this yields sub[0] == l/stride[0] and sub[1] == 0. The std::array
//     overload instead computes sub[1] first when stride[1] > stride[0].
//   - The loop should start from i = ndim.
//   - The opposite branch and the return statement were stripped.
// sub2lin / offset_sub2lin:
//   - The std::vector loop bodies were stripped.
//   - The std::array overloads return sub[0]*str[0] + sub[1]*str[1] and
//     (off[0]+sub[0])*str[0] + (off[1]+sub[1])*str[1], respectively.
stride.size(); if (1 == ndim) sub.push_back(l); else if (1 < ndim) { sub.resize(ndim); if (stride[ndim-1] > stride[0]) for (I i=ndim-1; i--; ){ sub[i] = l/stride[i]; l -= sub[i]*stride[i]; } else for (I i=0; i std::array lin2sub(I l, const std::array& stride){ std::array sub; if (stride[1] > stride[0]){ sub[1] = l/stride[1]; sub[0] = (l-sub[1]*stride[1])/stride[0]; } else { sub[0] = l/stride[0]; sub[1] = (l-sub[0]*stride[0])/stride[1]; } return sub; } // template // I sub2lin(const std::vector& sub, const std::vector& stride){ // assert(sub.size() == stride.size()); // #if defined(__GNUC__) && (__GNUC__ < 9 || (__GNUC__ == 9 && __GNUC_MINOR__ <= 2)) // // serial inner_product // return std::inner_product(sub.begin(), sub.end(), stride.begin(), I(0)); // #else // // parallelized inner_product // return std::transform_reduce(sub.begin(), sub.end(), stride.begin(), I(0)); // #endif // } template I sub2lin(const std::vector& sub, const std::vector& str){ I lin{0}; for (size_t i=0; i I sub2lin(const std::array& sub, const std::array& str){ return sub[0]*str[0] + sub[1]*str[1]; } template I sub2lin(const I s0, const I s1, const std::array& str){ return s0*str[0] + s1*str[1]; } template I offset_sub2lin(const std::vector& off, const std::vector& sub, const std::vector& str){ I lin{0}; for (size_t i=0; i I offset_sub2lin(const std::array& off, const std::array& sub, const std::array& str){ return (off[0]+sub[0])*str[0] + (off[1]+sub[1])*str[1]; } template I offset_sub2lin(const std::array& off, const I s0, const I s1, const std::array& str){ return (off[0]+s0)*str[0] + (off[1]+s1)*str[1]; } } // end namespace brille #endif