//
// Adaptation of ITensor's bondgate.h for 3-site gates
//    
#ifndef __BONDGATE3_H_
#define __BONDGATE3_H_

#include <utility>

#include "itensor/iqtensor.h"
#include "itensor/mps/siteset.h"

using namespace itensor;

// Forward declaration, so that the convenience aliases below can be
// introduced ahead of the full class definition.
template <typename Tensor>
class BondGate3;

// Three-site gate acting on ITensor objects.
using Gate3 = BondGate3<ITensor>;
// Three-site gate acting on IQTensor objects.
using IQGate3 = BondGate3<IQTensor>;

// A gate acting on the three consecutive sites i1, i1+1 and i1+2,
// adapted from ITensor's two-site BondGate.
//
// The gate tensor is either built from a three-site Hamiltonian term as a
// real- or imaginary-time exponential (tReal / tImag), or supplied
// directly by the caller (Custom).
template <class Tensor>
class BondGate3
    {
    public:

    enum Type { tReal,  //real-time gate
                tImag,  //imaginary-time gate
                Custom };

    // NOTE(review): this constructor is declared but no definition appears
    // in this header; any call to it will fail to link unless it is
    // defined elsewhere.
    BondGate3(SiteSet const& sites, 
             int i1);

    // Build the gate from a three-site term bondH on sites i1..i1+2;
    // type must be tReal or tImag.
    BondGate3(SiteSet const& sites, 
             int i1, 
             Type type, 
             Real tau, 
             Tensor bondH);

    // Wrap a caller-supplied gate tensor acting on sites i1..i1+2.
    BondGate3(SiteSet const& sites, 
             int i1, 
             Tensor gate);

    // Leftmost site the gate acts on.
    int i1() const { return i1_; }

    // Middle site (i1()+1).
    int i2() const { return i2_; }

    // Rightmost site (i1()+2).
    int i3() const { return i3_; }

    // Implicit conversion so a gate can be passed wherever a Tensor is
    // expected.
    operator const Tensor&() const { return gate_; }

    // The underlying gate tensor.
    Tensor const&
    gate() const { return gate_; }

    // How the gate was constructed (tReal, tImag or Custom).
    Type
    type() const { return type_; }

    private:

    // Default member initializers keep every field determinate, even for a
    // constructor that does not set them explicitly.
    Type type_ = Custom;
    int i1_ = 0; // leftmost site
    int i2_ = 0; // middle site
    int i3_ = 0; // rightmost site
    Tensor gate_;

    };

// Apply gate G to tensor T, with the gate written on the left.
// T is taken by value, multiplied in place by G.gate(), and returned,
// so the caller's tensor is left untouched.
template <typename Tensor>
Tensor
operator*(BondGate3<Tensor> const& G, Tensor T)
    {
    T *= G.gate();
    return T;
    }

// Apply gate G to tensor T, with the gate written on the right.
// This computes the same product as G * T: the local copy of T is
// multiplied by G.gate() and returned.
template <typename Tensor>
Tensor
operator*(Tensor T, BondGate3<Tensor> const& G)
    {
    auto const& g = G.gate();
    T *= g;
    return T;
    }


// Build the three-site gate exp(x) acting on sites i1, i1+1, i1+2, where
//   x = -tau * bondH        for tImag
//   x = -i * tau * bondH    for tReal.
// The exponential is evaluated with a fixed-order truncated Taylor series
// in Horner form (terms up to x^200/200!).
//
// sites  : provides the "Id" operators used to build the identity term
// i1     : leftmost site; the gate also acts on i1+1 and i1+2
// type   : must be tReal or tImag, otherwise Error is raised
// tau    : time step
// bondH  : three-site term, taken by value and rescaled locally.
//          NOTE(review): assumes bondH carries unprimed and prime-1 site
//          indices, like the "Id" operators; confirm for callers.
template <class Tensor>
BondGate3<Tensor>::
BondGate3(SiteSet const& sites, 
         int i1, 
         Type type, 
         Real tau, 
         Tensor bondH)
  : type_(type),
    i1_(i1),
    i2_(i1+1),
    i3_(i1+2)
    {
    if(!(type_ == tReal || type_ == tImag))
        {
        Error("When providing bondH, type must be tReal or tImag");
        }
    bondH *= -tau;
    Tensor unit = sites.op("Id",i1_)*sites.op("Id",i2_)*sites.op("Id",i3_);
    if(type_ == tReal)
        {
        bondH *= Complex_i;
        }

    // term starts as x itself (unshifted prime levels).
    auto term = bondH;
    // Shift bondH's prime levels (1->2, then 0->1). Multiplying an operator
    // with prime levels 0/1 by this shifted copy contracts the shared
    // prime-1 indices and leaves prime levels 0/2; mapprime(2,1) below then
    // restores 0/1, giving an operator product.
    bondH.mapprime(1,2);
    bondH.mapprime(0,1);

    // exp(x) = 1 + x +  x^2/2! + x^3/3! ..
    // = 1 + x * (1 + x/2 *(1 + x/3 * (...
    // ~ ((x/3 + 1) * x/2 + 1) * x + 1
    const int maxOrder = 200;
    for(int ord = maxOrder; ord >= 1; --ord)
        {
        term /= ord;
        gate_ = unit + term;
        // After the final (ord == 1) step only gate_ is needed; skip the
        // three-site operator product that would otherwise be discarded.
        if(ord == 1) break;
        term = gate_ * bondH;
        term.mapprime(2,1);
        }
    }

// Wrap a caller-supplied gate tensor acting on sites i1, i1+1 and i1+2.
// The gate is stored as given, with type Custom. The site set is not
// consulted, so its parameter name is commented out to avoid
// unused-parameter warnings.
template <class Tensor>
BondGate3<Tensor>::
BondGate3(SiteSet const& /*sites*/, 
         int i1, 
         Tensor gate)
  : type_(Custom),
    i1_(i1),
    i2_(i1+1),
    i3_(i1+2),
    // gate is a by-value sink parameter: move it instead of copying.
    gate_(std::move(gate))
    {
    }



#endif
