Program Listing for File interpolator2_cost.tpp

Return to documentation for file (src/interpolator2_cost.tpp)

/* This file is part of brille.

Copyright © 2020 Greg Tucker <greg.tucker@stfc.ac.uk>

brille is free software: you can redistribute it and/or modify it under the
terms of the GNU Affero General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your option)
any later version.

brille is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with brille. If not, see <https://www.gnu.org/licenses/>.            */

/*! \brief Add the pairwise branch-to-branch cost between two stored entries

For every ordered pair (i, j) of branches, where branch i is taken from the
data at index `i0` and branch j from the data at index `i1`, a weighted sum of
up to three partial costs is *added* to `cost[i*branches()+j]`:

 - a scalar cost from `_scalarfun` over the first `_elements[0]` values,
 - a vector cost from `_vectorfun` over the following `_elements[1]` values,
 - a matrix cost, the sum of `brille::utils::frobenius_distance(3u, ...)` over
   consecutive groups of nine values (`_elements[2]/9` groups in total;
   presumably each group is one 3x3 matrix -- confirm against the helper).

The three partial costs are weighted by `_costmult[0..2]`. A partial cost whose
element count is zero contributes zero.

\param i0   index handed to `data_.ptr` to obtain the first set of branches
\param i1   index handed to `data_.ptr` to obtain the second set of branches
\param cost accumulator indexed as `i*branches()+j`; no bounds check is
            performed here, so it must already hold at least
            `branches()*branches()` elements
\param arbitrary_phase_allowed when true each branch of the second entry is
            first passed through `brille::utils::inplace_antiphase` together
            with the branch of the first entry, and the result (held in a
            scratch buffer) is what gets compared
*/
template<class T>
template<typename S>
void
Interpolator<T>::add_cost(const ind_t i0, const ind_t i1, std::vector<S>& cost, const bool arbitrary_phase_allowed) const {
  const T * const first = data_.ptr(i0);
  const T * const second = data_.ptr(i1);
  auto counts = _elements;
  const ind_t span{this->branch_span()};
  const ind_t n_branches{this->branches()};
  const ind_t matrix_offset{counts[0]+counts[1]};
  // number of nine-element groups in the matrix part (zero if there is none)
  const auto n_matrices = counts[2]/9;

  // Compute the weighted cost of matching branch i (at lhs) with branch j
  // (at rhs) and add it to the appropriate element of the cost matrix.
  auto accumulate = [&](const ind_t i, const ind_t j, const T * lhs, const T * rhs){
    S scalar_part{0}, vector_part{0}, matrix_part{0};
    if (counts[0])
      scalar_part = this->_scalarfun(counts[0], lhs, rhs);
    if (counts[1])
      vector_part = this->_vectorfun(counts[1], lhs+counts[0], rhs+counts[0]);
    for (ind_t m=0; m<n_matrices; ++m)
      matrix_part += brille::utils::frobenius_distance(3u, lhs+matrix_offset+9u*m, rhs+matrix_offset+9u*m);
    cost[i*n_branches+j] += _costmult[0]*scalar_part + _costmult[1]*vector_part + _costmult[2]*matrix_part;
  };

  if (!arbitrary_phase_allowed){
    // compare the stored values directly
    for (ind_t i=0; i<n_branches; ++i){
      const T * lhs = first+i*span;
      for (ind_t j=0; j<n_branches; ++j)
        accumulate(i, j, lhs, second+j*span);
    }
    return;
  }

  // if the _vectorfun uses the Hermitian angle, e^iθ *never* matters.
  // One scratch branch is allocated up front and overwritten for every pair.
  auto scratch = std::unique_ptr<T[]>(new T[span]);
  for (ind_t i=0; i<n_branches; ++i){
    const T * lhs = first+i*span;
    for (ind_t j=0; j<n_branches; ++j){
      brille::utils::inplace_antiphase(span, lhs, second+j*span, scratch.get());
      accumulate(i, j, lhs, scratch.get());
    }
  }
}