Program Listing for File interpolator2.hpp

Return to documentation for file (src/interpolator2.hpp)

/* This file is part of brille.

Copyright © 2019,2020 Greg Tucker <greg.tucker@stfc.ac.uk>

brille is free software: you can redistribute it and/or modify it under the
terms of the GNU Affero General Public License as published by the Free
Software Foundation, either version 3 of the License, or (at your option)
any later version.

brille is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.
See the GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with brille. If not, see <https://www.gnu.org/licenses/>.            */
#ifndef BRILLE_INTERPOLATOR_HPP_
#define BRILLE_INTERPOLATOR_HPP_
#include <vector>
#include <array>
#include <utility>
#include <mutex>
#include <cassert>
#include <functional>
#include <omp.h>
#include "array_latvec.hpp" // defines bArray
#include "phonon.hpp"
#include "permutation.hpp"
#include "permutation_table.hpp"
#include "approx.hpp"
#include "utilities.hpp"
namespace brille {

//template<class T> using CostFunction = std::function<typename CostTraits<T>::type(ind_t, T*, T*)>;
template<class T>
using CostFunction = std::function<double(brille::ind_t, const T*, const T*)>;

template<class T> struct is_complex {enum{value = false};};
template<class T> struct is_complex<std::complex<T>> {enum {value=true};};
// template<bool C, typename T> using enable_if_t = typename std::enable_if<C,T>::type;

enum class RotatesLike {
  Real, Reciprocal, Axial, Gamma
};

template<class T>
class Interpolator{
public:
  using ind_t = brille::ind_t;
  template<class Z> using element_t =std::array<Z,3>;
  using costfun_t = CostFunction<T>;
  using shape_t = std::vector<ind_t>;
private:
  bArray<T> data_;
  shape_t shape_;
  element_t<ind_t> _elements;
  RotatesLike rotlike_;
  element_t<double> _costmult;
  costfun_t _scalarfun;
  costfun_t _vectorfun;
  //costfun_t _matrixfun; //!< A function to calculate the differences between matrices at two stored points
public:
  explicit Interpolator(size_t scf_type=0, size_t vcf_type=0)
  : data_(0,0), _elements({{0,0,0}}), rotlike_{RotatesLike::Real}, _costmult({{1,1,1}})
  {
    this->set_cost_info(scf_type, vcf_type);
  }
  Interpolator(costfun_t scf, costfun_t vcf)
  : data_(0,0), _elements({{0,0,0}}), rotlike_{RotatesLike::Real},
    _costmult({{1,1,1}}), _scalarfun(scf), _vectorfun(vcf)
  {}
  Interpolator(bArray<T>& d, shape_t sh, element_t<ind_t> el, RotatesLike rl)
  : data_(d), shape_(sh), _elements(el), rotlike_{rl}, _costmult({{1,1,1}})
  {
    this->set_cost_info(0,0);
    this->check_elements();
  }
  Interpolator(bArray<T>& d, shape_t sh, element_t<ind_t> el, RotatesLike rl, size_t csf, size_t cvf, element_t<double> wg)
  : data_(d), shape_(sh), _elements(el), rotlike_{rl}, _costmult(wg)
  {
    this->set_cost_info(csf, cvf);
    this->check_elements();
  }
  // use the Array2<T>(const Array<T>&) constructor
  Interpolator(brille::Array<T>& d, element_t<ind_t> el, RotatesLike rl)
  : data_(d), shape_(d.shape()), _elements(el), rotlike_{rl}, _costmult({{1,1,1}})
  {
    this->set_cost_info(0,0);
    this->check_elements();
  }
  // use the Array2<T>(const Array<T>&) constructor
  Interpolator(brille::Array<T>& d, element_t<ind_t> el, RotatesLike rl, size_t csf, size_t cvf, element_t<double> wg)
  : data_(d), shape_(d.shape()), _elements(el), rotlike_{rl}, _costmult(wg)
  {
    this->set_cost_info(csf, cvf);
    this->check_elements();
  }
  //
  void setup_fake(const ind_t sz, const ind_t br){
    data_ = bArray<T>(sz, br);
    shape_ = {sz, br};
    _elements = {1u,0u,0u};
  }
  //
  void set_cost_info(const int scf, const int vcf){
    switch (scf){
      default:
      this->_scalarfun = [](ind_t n, const T* i, const T* j){
        double s{0};
        for (ind_t z=0; z<n; ++z) s += brille::utils::magnitude(i[z]-j[z]);
        return s;
      };
    }
    switch (vcf){
      case 1:
      debug_update("selecting brille::utils::vector_distance");
      this->_vectorfun = [](ind_t n, const T* i, const T* j){
        return brille::utils::vector_distance(n, i, j);
      };
      break;
      case 2:
      debug_update("selecting 1-brille::utils::vector_product");
      this->_vectorfun = [](ind_t n, const T* i, const T* j){
        return 1-brille::utils::vector_product(n, i, j);
      };
      break;
      case 3:
      debug_update("selecting brille::utils::vector_angle");
      this->_vectorfun = [](ind_t n, const T* i, const T* j){
        return brille::utils::vector_angle(n, i, j);
      };
      break;
      case 4:
      debug_update("selecting brille::utils::hermitian_angle");
      this->_vectorfun = [](ind_t n, const T* i, const T* j){
         return brille::utils::hermitian_angle(n,i,j);
      };
      break;
      default:
      debug_update("selecting sin**2(brille::utils::hermitian_angle)");
      // this->_vectorfun = [](ind_t n, T* i, T* j){return std::abs(std::sin(brille::utils::hermitian_angle(n, i, j)));};
      this->_vectorfun = [](ind_t n, const T* i, const T* j){
        auto sin_theta_H = std::sin(brille::utils::hermitian_angle(n, i, j));
        return sin_theta_H*sin_theta_H;
      };
    }
  }
  void set_cost_info(const int scf, const int vcf, const element_t<double>& elcost){
    _costmult = elcost;
    this->set_cost_info(scf, vcf);
  }
  //
  ind_t size(void) const {return data_.size(0);}
  ind_t branches(void) const {
    /*
    The data Array is allowed to be anywhere from 1 to 5 dimensional.
    Possible shapes are:
      (P,)            - one branch with one scalar per point
      (P, X)          - one branch with X elements per point
      (P, B, Y)       - B branches with Y elements per branch per point
      (P, B, V, 3)    - B branches with V 3-vectors per branch per point
      (P, B, M, 3, 3) - B branches with M (3,3) matrices per branch per point
    In the two-dimensional case there is a potential ambiguity in that X counts
    both the number of branches and the number of elements.
      X = B*Y
    where Y is the sum of scalar, vector, and matrix elements per branch
      Y = S + 3*V + 9*M
    */
    ind_t nd = shape_.size();
    ind_t b = nd>1 ? shape_[1] : 1u;
    if (2 == nd){
      ind_t y = std::accumulate(_elements.begin(), _elements.end(), ind_t(0));
      // zero-elements is the special (initial) case → 1 scalar per branch
      if (y > 0) b /= y;
    }
    return b;
  }
  bool only_vector_or_matrix(void) const {
    // if V or M can not be deduced directly from the shape of data_
    // then this Interpolator represents mixed data
    // purely-scalar data is also classed as 'mixed' for our purposes
    ind_t nd = shape_.size();
    if (5u == nd && 3u == shape_[4] && 3u == shape_[3]) return true;
    if (4u == nd && 3u == shape_[2]) return true;
    if (nd < 4u) return false; // (P,), (P,X), (P,B,Y)
    std::string msg = "Interpolator can not handle a {";
    for (auto x: shape_) msg += " " + std::to_string(x);
    msg += " } data array";
    throw std::runtime_error(msg);
  }
  const bArray<T>& data(void) const {return data_;}
  shape_t shape(void) const {return shape_;};
  brille::Array<T> array(void) const {return brille::Array<T>(data_,shape_);}
  element_t<ind_t> elements(void) const {return _elements;}
  //
  template<class... Args>
  void interpolate_at(Args... args) const {
      this->interpolate_at_mix(args...);
  }
  //
  template<class R, class RotT>
  bool rotate_in_place(bArray<T>& x,
                       const LQVec<R>& q,
                       const RotT& rt,
                       const PointSymmetry& ps,
                       const std::vector<size_t>& r,
                       const std::vector<size_t>& invr,
                       const int nth) const {
    switch (rotlike_){
      case RotatesLike::Real:       return this->rip_real(x,ps,r,invr,nth);
      case RotatesLike::Axial:      return this->rip_axial(x,ps,r,invr,nth);
      case RotatesLike::Reciprocal: return this->rip_recip(x,ps,r,invr,nth);
      case RotatesLike::Gamma:      return this->rip_gamma(x,q,rt,ps,r,invr,nth);
      default: throw std::runtime_error("Impossible RotatesLike value!");
    }
  }
  //
  RotatesLike rotateslike() const { return rotlike_; }
  RotatesLike rotateslike(const RotatesLike a) {
    rotlike_ = a;
    return rotlike_;
  }
  // Replace the data within this object.
  template<class I>
  void replace_data(
      const bArray<T>& nd,
      const shape_t sh,
      const std::array<I,3>& ne,
      const RotatesLike rl = RotatesLike::Real)
  {
    data_ = nd;
    shape_ = sh;
    rotlike_ = rl;
    // convert the elements datatype as necessary
    if (ne[1]%3)
      throw std::logic_error("Vectors must have 3N elements per branch");
    if (ne[2]%9)
      throw std::logic_error("Matrices must have 9N elements per branch");
    for (size_t i=0; i<3u; ++i) _elements[i] = static_cast<ind_t>(ne[i]);
    this->check_elements();
  }
  template<class I>
  void replace_data(
      const brille::Array<T>& nd,
      const std::array<I,3>& ne,
      const RotatesLike rl = RotatesLike::Real)
  {
    data_ = bArray<T>(nd);
    shape_ = nd.shape();
    rotlike_ = rl;
    // convert the elements datatype as necessary
    if (ne[1]%3)
      throw std::logic_error("Vectors must have 3N elements per branch");
    if (ne[2]%9)
      throw std::logic_error("Matrices must have 9N elements per branch");
    for (size_t i=0; i<3u; ++i) _elements[i] = static_cast<ind_t>(ne[i]);
    this->check_elements();
  }
  // Replace the data in this object without specifying the data shape or its elements
  // this variant is necessary since the template specialization above can not have a default value for the elements
  template<template<class> class A>
  void replace_data(const A<T>& nd){
    return this->replace_data(nd, element_t<ind_t>({{0,0,0}}));
  }
  ind_t branch_span() const { return this->branch_span(_elements);}
  //
  std::string to_string() const {
    std::string str= "{ ";
    for (auto s: shape_) str += std::to_string(s) + " ";
    str += "} data";
    auto b = this->branches();
    if (b){
      str += " with " + std::to_string(b) + " mode";
      if (b>1) str += "s";
    }
    auto n = std::count_if(_elements.begin(), _elements.end(), [](ind_t a){return a>0;});
    if (n){
      str += " of ";
      std::array<std::string,3> types{"scalar", "vector", "matrix"};
      for (size_t i=0; i<3u; ++i) if (_elements[i]) {
        str += std::to_string(_elements[i]) + " " + types[i];
        if (--n>1) str += ", ";
        if (1==n) str += " and ";
      }
      str += " element";
      if (this->branch_span()>1) str += "s";
    }
    return str;
  }

  template<class S>
  void add_cost(const ind_t, const ind_t, std::vector<S>&, bool) const;

  template<typename I>
  bool any_equal_modes(const I idx) const {
    return this->any_equal_modes_(static_cast<ind_t>(idx), this->branches(), this->branch_span());
  }
  size_t bytes_per_point() const {
    size_t n_elements = data_.numel()/data_.size(0);
    return n_elements * sizeof(T);
  }
private:
  void check_elements(void){
    // check the input for correctness
    ind_t x = this->branch_span(_elements);
    switch (shape_.size()) {
      case 1u: // 1 scalar per branch per point
        if (0u == x) x = _elements[0] = 1u;
        if (x > 1u) throw std::runtime_error("1-D data must represent one scalar per point!") ;
        break;
      case 2u: // (P, B*Y)
        if (0u == x) x = _elements[0] = shape_[1]; // one branch with y scalars per point
        if (shape_[1] % x)
          throw std::runtime_error("2-D data requires an integer number of branches!");
        break;
      case 3u: // (P, B, Y)
        if (0u == x) x = _elements[0] = shape_[2];
        if (shape_[2] != x)
          throw std::runtime_error("3-D data requires that the last dimension matches the specified number of elements!");
        break;
      case 4u: // (P, B, V, 3)
        if (3u != shape_[3])
          throw std::runtime_error("4-D data can only be 3-vectors");
        if (0u == x) x = _elements[1] = shape_[2]*3u;
        if (shape_[2]*3u != x)
          throw std::runtime_error("4-D data requires that the last two dimensions match the specified number of vector elements!");
        break;
      case 5u: // (P, B, M, 3, 3)
        if (3u != shape_[3] || 3u != shape_[4])
          throw std::runtime_error("5-D data can only be matrices");
        if (0u == x) x = _elements[2] = shape_[2]*9u;
        if (shape_[2]*9u != x)
          throw std::runtime_error("5-D data requires the last three dimensions match the specified number of matrix elements!");
        break;
      default: // higher dimensions not (yet) supported
        throw std::runtime_error("Interpolator data is expected to be 1- to 5-D");
    }
  }
  bool any_equal_modes_(const ind_t idx, const ind_t b_, const ind_t s_) {
    // since we're probably only using this when the data is provided and
    // most eigenproblem solvers sort their output by eigenvalue magnitude it is
    // most-likely for mode i and mode i+1 to be equal.
    // ∴ search (i,j), (i+1,j+1), (i+2,j+2), ..., i ∈ (0,N], j ∈ (1,N]
    // for each j = i+1, i+2, i+3, ..., i+N-1
    if (b_ < 2) return false;
    // data_ is always 2D: (N,1), (N,B), or (N,Y)
    for (ind_t offset=1; offset < b_; ++offset)
    for (ind_t i=0, j=offset; j < b_; ++i, ++j)
    if (brille::approx::vector(s_, data_.ptr(idx, i*s_), data_.ptr(idx, j*s_))) return true;
    // no matches
    return false;
  }
  template<typename I> ind_t branch_span(const std::array<I,3>& e) const {
    return static_cast<ind_t>(e[0])+static_cast<ind_t>(e[1])+static_cast<ind_t>(e[2]);
  }
  element_t<ind_t> count_scalars_vectors_matrices(void) const {
    element_t<ind_t> no{_elements[0], _elements[1]/3u, _elements[2]/9u};
    return no;
  }
  // the 'mixed' variants of the rotate_in_place implementations
  bool rip_real(bArray<T>&, const PointSymmetry&, const std::vector<size_t>&, const std::vector<size_t>&, const int) const;
  bool rip_recip(bArray<T>&, const PointSymmetry&, const std::vector<size_t>&, const std::vector<size_t>&, const int) const;
  bool rip_axial(bArray<T>&, const PointSymmetry&, const std::vector<size_t>&, const std::vector<size_t>&, const int) const;
  template<class R>
  bool rip_gamma_complex(bArray<T>&, const LQVec<R>&, const GammaTable&, const PointSymmetry&, const std::vector<size_t>&, const std::vector<size_t>&, const int) const;
  template<class R, class S=T>
  enable_if_t<is_complex<S>::value, bool>
  rip_gamma(bArray<T>& x, const LQVec<R>& q, const GammaTable& gt, const PointSymmetry& ps, const std::vector<size_t>& r, const std::vector<size_t>& ir, const int nth) const{
    return rip_gamma_complex(x, q, gt, ps, r, ir, nth);
  }
  template<class R, class S=T>
  enable_if_t<!is_complex<S>::value, bool>
  rip_gamma(bArray<T>&, const LQVec<R>&, const GammaTable&, const PointSymmetry&, const std::vector<size_t>&, const std::vector<size_t>&, const int) const{
    throw std::runtime_error("RotatesLike == Gamma requires complex valued data!");
  }

  // interpolate_at_*
  void interpolate_at_mix(const std::vector<std::vector<ind_t>>&, const std::vector<ind_t>&, const std::vector<double>&, bArray<T>&, const ind_t, const bool) const;
  void interpolate_at_mix(const std::vector<std::vector<ind_t>>&, const std::vector<std::pair<ind_t,double>>&, bArray<T>&, const ind_t, const bool) const;
};

#include "interpolator2_at.tpp"
#include "interpolator2_axial.tpp"
#include "interpolator2_cost.tpp"
#include "interpolator2_gamma.tpp"
#include "interpolator2_real.tpp"
#include "interpolator2_recip.tpp"

} // namespace brille
#endif