#include <algorithm>

#include "itensor/all.h"
#include "parameters.h"
#include "bondgate3.h"

using namespace itensor;
using namespace std;

using Gate = BondGate<ITensor>;

using namespace parameters;


// return the 3-site gate, whose leftmost site is "site", involved in the Trotterization of the Hamiltonian
Gate3 three_site_gate(const int site, const SpinHalf& sites, const Real tstep){
    const int N = sites.N();
    if (site+2>N)
        throw runtime_error("Invalid position for a three-site gate");
    
    //NNN term
    auto hterm = J*lam1*toITensor(sites.op("Sz",site)*sites.op("Id",site+1)*sites.op("Sz",site+2));
    hterm += J*lam3*toITensor(sites.op("Sx",site))*sites.op("Id",site+1)*toITensor(sites.op("Sx",site+2));
    hterm += J*lam3*toITensor(sites.op("Sy",site))*sites.op("Id",site+1)*toITensor(sites.op("Sy",site+2));

    //left bond
    auto left_bond = J*toITensor(sites.op("Sz",site)*sites.op("Sz",site+1));
    left_bond += J*(lam2-pairing)*toITensor(sites.op("Sx",site))*toITensor(sites.op("Sx",site+1));
    left_bond += J*(lam2+pairing)*toITensor(sites.op("Sy",site))*toITensor(sites.op("Sy",site+1));

    left_bond += (site%2==0 ? 1. : -1.)*J*yz*toITensor(sites.op("Sy",site))*toITensor(sites.op("Sz",site+1));
    left_bond += (site%2==0 ? 1. : -1.)*J*yz*toITensor(sites.op("Sz",site))*toITensor(sites.op("Sy",site+1));

    left_bond*= toITensor(sites.op("Id",site+2));
    left_bond/=2;


    //right bond
    auto right_bond = J*toITensor(sites.op("Sz",site+1)*sites.op("Sz",site+2));
    right_bond += J*(lam2-pairing)*toITensor(sites.op("Sx",site+1))*toITensor(sites.op("Sx",site+2));
    right_bond += J*(lam2+pairing)*toITensor(sites.op("Sy",site+1))*toITensor(sites.op("Sy",site+2));

    //note the change of sign from left_bond
    right_bond += (site%2==0 ? -1. : 1.)*J*yz*toITensor(sites.op("Sy",site+1))*toITensor(sites.op("Sz",site+2));
    right_bond += (site%2==0 ? -1. : 1.)*J*yz*toITensor(sites.op("Sz",site+1))*toITensor(sites.op("Sy",site+2));

    right_bond*= toITensor(sites.op("Id",site));
    right_bond/=2;

    //left field
    auto left_field = h*toITensor(sites.op("Sx",site));
    left_field += hz*toITensor(sites.op("Sz",site));
    left_field += hy*toITensor(sites.op("Sy",site));
    left_field /= 3;
    left_field *= toITensor(sites.op("Id",site+1)) * toITensor(sites.op("Id",site+2));

    //middle field
    auto middle_field = h*toITensor(sites.op("Sx",site+1));
    middle_field += hz*toITensor(sites.op("Sz",site+1));
    middle_field += hy*toITensor(sites.op("Sy",site+1));
    middle_field /= 3;
    middle_field *= toITensor(sites.op("Id",site)) * toITensor(sites.op("Id",site+2));

    //right field
    auto right_field = h*toITensor(sites.op("Sx",site+2));
    right_field += hz*toITensor(sites.op("Sz",site+2));
    right_field += hy*toITensor(sites.op("Sy",site+2));
    right_field /= 3;
    right_field *= toITensor(sites.op("Id",site)) * toITensor(sites.op("Id",site+1));

    if (site == 1){
        left_field *= 3;
        left_bond *= 2;
        middle_field *= 3./2.;
    }else if(site==2){
        left_field *= 3./2.;
    }else if (site==N-3){
        right_field *= 3./2.;
    }else if (site == N-2){
        right_field *= 3;
        right_bond *= 2;
        middle_field *= 3./2.;
    }

    hterm += left_field + middle_field + right_field;
    hterm += left_bond + right_bond;

    auto g = Gate3(sites,site,Gate3::tReal,tstep,hterm);
    return g;
}

//return the Hamiltonian
// Build the full Hamiltonian of the chain as an MPO (exact AutoMPO conversion):
//   H = sum_{i=1}^{N-1} [ J Sz_i Sz_{i+1} + J(lam2-pairing) Sx_i Sx_{i+1}
//                         + J(lam2+pairing) Sy_i Sy_{i+1}
//                         + s_i J yz (Sy_i Sz_{i+1} + Sz_i Sy_{i+1}) ]
//     + sum_{i=1}^{N}   [ h Sx_i + hz Sz_i + hy Sy_i ]
//     + sum_{i=1}^{N-2} [ J lam1 Sz_i Sz_{i+2} + J lam3 (Sx_i Sx_{i+2} + Sy_i Sy_{i+2}) ]
// where s_i = +1 for even i and -1 for odd i. The couplings are read from
// namespace parameters.
MPO Hamiltonian(const SiteSet& sites){
    const int n_sites = sites.N();
    AutoMPO terms(sites);

    // nearest-neighbour bonds (b, b+1)
    for(int b = 1; b < n_sites; ++b){
        const double stagger = (b%2 == 0) ? 1. : -1.;
        terms += J,"Sz",b,"Sz",b+1;
        terms += J*(lam2-pairing),"Sx",b,"Sx",b+1;
        terms += J*(lam2+pairing),"Sy",b,"Sy",b+1;
        terms += stagger*J*yz,"Sy",b,"Sz",b+1;
        terms += stagger*J*yz,"Sz",b,"Sy",b+1;
    }

    // single-site fields
    for(int j = 1; j <= n_sites; ++j){
        terms += h,"Sx",j;
        terms += hz,"Sz",j;
        terms += hy,"Sy",j;
    }

    // next-nearest-neighbour couplings (j, j+2)
    for(int j = 1; j+2 <= n_sites; ++j){
        terms += J*lam1,"Sz",j,"Sz",j+2;
        terms += J*lam3,"Sx",j,"Sx",j+2;
        terms += J*lam3,"Sy",j,"Sy",j+2;
    }

    return toMPO<ITensor>(terms,Args("Exact",true));
}