#ifndef __TDMRG3_H_
#define __TDMRG3_H_

#include "itensor/all.h"
#include "bondgate3.h"


using namespace itensor;
using namespace std;
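//
// Adaptation of the ITensor gateTEvol routine to compute structure factors
// for Hamiltonians with next-nearest-neighbour interactions, i.e. using
// three-site bond gates (acting on sites i1, i2, i3).
//
// Evolves an MPS in real or imaginary time by an amount ttotal in steps
// of tstep using the list of bond gates provided.
//
// Arguments recognized:
//    "Verbose": if true, print useful information to stdout
//    "Normalize": if true (default), normalize psi after every time step
//    "MeasureEvery": call obs.measure every this many time steps
//    "Cutoff": truncation cutoff used when splitting the merged gate tensors (required)
//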


// Splits AA into MPS_tensor, whose indices on entry select which indices of AA
// go into it, and a remainder that is stored back into AA. Defined below.
template <class Tensor>
void two_split(Tensor& AA, Tensor& MPS_tensor, const Args& args);

// Splits the merged three-site tensor AA (sites AA_left..AA_left+2) back into
// psi, leaving the orthogonality centre at new_centre (AA_left or AA_left+2).
// Defined below.
template <class Tensor>
void restore_MPS(Tensor& AA, MPSt<Tensor>& psi, int AA_left, int new_centre, const Args& args);

template <class Iterable, class Tensor, class Observer_t>
Real
gateTEvol3(Iterable const& gatelist, 
          Real ttotal, 
          Real tstep, 
          MPSt<Tensor>& psi, 
          Observer_t& obs,
          Args args){
    // Applies every three-site gate of gatelist (acting on sites g->i1(),
    // g->i2(), g->i3()) to psi once per time step, for ttotal/tstep steps.
    //
    // Consecutive gates that share two sites are applied to a merged tensor
    // AA without returning to MPS form: only the site leaving the gate
    // support is split off with two_split. psi is put back into MPS form with
    // restore_MPS when the next gate does not overlap that way, and after the
    // last gate of each sweep.
    //
    // Arguments read here: "Verbose" (default false), "Normalize" (default
    // true), "MeasureEvery" (default 1, must be positive). args is also
    // forwarded to psi.orthogonalize, two_split and restore_MPS.
    // obs.measure is called once before evolving and then every
    // "MeasureEvery" steps, with "TimeStepNum", "Time" and "TotalTime" added
    // to args.
    //
    // Returns norm(psi) at entry multiplied by every factor returned by
    // psi.normalize() (the latter only when "Normalize" is true).
    const bool verbose = args.getBool("Verbose",false);
    const bool normalize = args.getBool("Normalize",true);
    // Default: measure after every step. A non-positive value would make the
    // modulo below divide by zero.
    const int measure_every = args.getInt("MeasureEvery",1);
    if(measure_every <= 0){
        Error("MeasureEvery must be a positive integer");
    }

    const int nt = int(ttotal/tstep+(1e-9*(ttotal/tstep)));
    if(fabs(nt*tstep-ttotal) > 1E-9){
        Error("Timestep not commensurate with total time");
    }

    if(verbose){
        printfln("Taking %d steps of timestep %.5f, total time %.5f",nt,tstep,ttotal);
    }

    Real tot_norm = norm(psi);

    Real tsofar = 0;

    const int N = psi.N();

    //do a first measurement of the static quantities
    args.add("TimeStepNum",0);
    args.add("Time",tsofar);
    args.add("TotalTime", ttotal);
    obs.measure(psi,args);

    psi.orthogonalize(args);

    //time evolve
    for(int tt = 1; tt <= nt; ++tt){
        auto g = gatelist.begin();

        // MPS_form == true: every site tensor of psi is up to date.
        // Otherwise AA holds the two sites starting at AA_left, and the
        // corresponding entries of psi are stale.
        bool MPS_form = true;
        int AA_left = -1;
        Tensor AA;

        while(g != gatelist.end()){
            auto i1 = g->i1();
            auto i2 = g->i2();
            auto i3 = g->i3();

            if (MPS_form){
                // Move the orthogonality centre into [i1,i3] so the merged
                // tensor carries the norm of the state.
                if (psi.orthoCenter()<i1)
                    psi.position(i1);
                else if (psi.orthoCenter()>i3)
                    psi.position(i3);
                //else do nothing
                // Merge exactly the three sites this gate acts on.
                AA = psi.A(i1) * psi.A(i2);
                AA *= psi.A(i3);
            }else{
                // AA already holds two of the three sites: add the missing one.
                if (AA_left == i1)
                    AA *= psi.A(i3);
                else if (AA_left == i2)
                    AA *= psi.A(i1);
                else
                    throw runtime_error("The state was not returned to MPS form when it should have been");
            }
            AA *= g->gate();
            AA.mapprime(1,0,Site);

            ++g;            
            if(g != gatelist.end()){
                //Look ahead to next gate position
                auto ni1 = g->i1();
                auto ni2 = g->i2();
                auto ni3 = g->i3();
                if(ni1 == i2){ //next gate shifted right: restore A(i1), keep (i2,i3) in AA
                    auto p_ind = findtype(psi.A(i1), Site);    
                    // The index template placed on A(i1) tells two_split
                    // which indices of AA belong to the split-off tensor.
                    if (i1>1){
                        auto l_ind = commonIndex(psi.A(i1-1),AA);
                        psi.Aref(i1) = Tensor(p_ind,l_ind);
                    }else{
                        psi.Aref(i1) = Tensor(p_ind);
                    }
                    two_split(AA,psi.Aref(i1), args);
                    psi.leftLim(i1);
                    MPS_form=false;
                    AA_left = i2;
                }else if (ni2 == i1){ //next gate shifted left: restore A(i3), keep (i1,i2) in AA
                    auto p_ind = findtype(psi.A(i3), Site);
                    if (i3<N){
                        auto l_ind = commonIndex(psi.A(i3+1),AA);
                        psi.Aref(i3) = Tensor(p_ind,l_ind);
                    }else{
                        psi.Aref(i3) = Tensor(p_ind);
                    }
                    two_split(AA,psi.Aref(i3), args);
                    psi.rightLim(i3);
                    MPS_form=false;
                    AA_left = i1;
                }else{ //restore MPS form
                    if (ni1 >= i3) //put orthogonality centre in i3
                        restore_MPS(AA, psi, i1, i3, args);
                    else if (ni3<=i1) //put orthogonality centre in i1
                        restore_MPS(AA, psi, i1, i1, args);
                    else
                        throw runtime_error("Something is wrong here");

                    MPS_form=true;
                    AA_left=-1;
                }
            }else{
                // Last gate of the sweep: always leave psi in MPS form.
                restore_MPS(AA, psi, i1, i1, args);
                MPS_form=true;
                AA_left=-1;
            }

        }

        if(normalize)
            tot_norm *= psi.normalize();

        tsofar += tstep;

        args.add("TimeStepNum",tt);
        args.add("Time",tsofar);
        args.add("TotalTime",ttotal);
        if (tt % measure_every == 0)
            obs.measure(psi,args);
    }

    if(verbose)
        printfln("\nTotal time evolved = %.5f\n",tsofar);

    return tot_norm;

} // gateTEvol3

template <class Tensor>
void two_split(Tensor& AA, Tensor& MPS_tensor, const Args& args){
    // Splits AA into MPS_tensor times a remainder, and stores the remainder
    // back into AA.
    //
    // On entry, MPS_tensor must carry the indices of AA that should end up in
    // it; it is passed as the left factor to svd / denmatDecomp and is
    // overwritten with that factor.
    //
    // "Cutoff" (required) selects the method:
    //  - below 1e-12: svd. The singular values are rescaled to unit norm when
    //    "Normalize" (default true, as in gateTEvol3) is set, and AA = D*V.
    //  - otherwise: denmatDecomp(..., Fromleft). AA = V, rescaled to unit
    //    norm only when "ExpensiveNormalize" (default false) is set.
    // Tensors whose norm is at most 1e-16 are never rescaled, to avoid
    // dividing by zero.
    Tensor D,V;
    const double cutoff = args.getReal("Cutoff");
    if (cutoff<1e-12){
        const bool normalize=args.getBool("Normalize",true);
        svd(AA,MPS_tensor,D,V,args);
        if (normalize) {
            const auto nrm = itensor::norm(D);
            if (nrm > 1E-16) D *= 1./nrm;
        }
        AA = D*V;
    }
    else{
        denmatDecomp(AA,MPS_tensor,V,Fromleft,args);
        // AA must always become the remainder V, even when V is too small to
        // be normalized; otherwise AA would still contain the sites that were
        // just moved into MPS_tensor.
        AA = V;
        const bool normalize=args.getBool("ExpensiveNormalize",false);
        if (normalize){
            const auto nrm = itensor::norm(V);
            if(nrm > 1E-16) AA = V/nrm;
        }
    }
}

template <class Tensor>
void restore_MPS(Tensor& AA, MPSt<Tensor>& psi, int AA_left, int new_centre, const Args& args){
    // Splits the merged three-site tensor AA, covering sites AA_left,
    // AA_left+1 and AA_left+2, back into three MPS tensors of psi. The
    // orthogonality centre is left on new_centre, which must be AA_left or
    // AA_left+2; any other value throws std::runtime_error before psi is
    // modified.
    //
    // j1: site that receives the final remainder of AA (the centre)
    // j2: middle site
    // j3: site farthest from the centre, split off first
    int j1,j2,j3;
    if (new_centre == AA_left){
        j1=AA_left;
        j2=j1+1;
        j3=j1+2;
    }else if (new_centre == AA_left+2){
        j1=AA_left+2;
        j2=j1-1;
        j3=j1-2;
    }else{
        throw runtime_error("Resotring MPS form by putting orthogonality entre in the middle of AA is useless. Make sure you need it");
    }

    // Neighbour of j3 just outside AA. It does not exist when j3 is the first
    // or last site of the chain; then j3 has no outer link index (same
    // treatment as the single-site restores in gateTEvol3).
    const int outer = (j3 > j1) ? j3+1 : j3-1;
    auto p_ind = findtype(psi.A(j3), Site);
    if (outer >= 1 && outer <= psi.N()){
        auto l_ind = commonIndex(psi.A(outer),AA);
        psi.Aref(j3) = Tensor(p_ind,l_ind);
    }else{
        psi.Aref(j3) = Tensor(p_ind);
    }
    // The index template placed on psi.Aref(j3) selects which indices of AA
    // go into the split-off tensor.
    two_split(AA, psi.Aref(j3), args);

    auto mid_p_ind = findtype(psi.A(j2), Site);
    auto mid_l_ind = commonIndex(psi.A(j3),AA);
    psi.Aref(j2) = Tensor(mid_p_ind,mid_l_ind);
    two_split(AA, psi.Aref(j2), args);

    psi.Aref(j1) = AA;

    // Only the limit on the far side of the centre is set explicitly; the
    // other limit is presumably adjusted by the Aref writes above (ITensor v2
    // updates orthogonality limits on Aref) — TODO confirm.
    if (new_centre == AA_left){
        psi.rightLim(new_centre+1);
    }else{
        psi.leftLim(new_centre-1);
    }

}



#endif
