#include "itensor/all.h"
#include "SiteOp.h"
//#include "my_TEBD.h"
#include "tDMRG3.h"
#include <iostream>
#include "DSF_Observer.h"
#include "gates.h"
#include "parameters.h"


using namespace itensor;
using namespace std;
using namespace parameters;


// Builds the MPO that main() hands to dmrg() as the Hamiltonian on `sites`.
// Not defined in this file. NOTE(review): presumably it reads the coupling
// globals (h, hy, hz, pairing, yz, lam1, lam2) set in main() — confirm at the definition.
MPO Hamiltonian(const SiteSet& sites);
// Returns twice the Sz expectation value averaged over the central part of the
// chain (defined at the bottom of this file). psi is taken by non-const
// reference because the function calls psi.position(i) while measuring.
double magnetization(MPS& psi);

int main(int argc, char* argv[]){

if (argc!=2){ 
	printfln("Usage: %s input_file",argv[0]); 
    return 0; 
}

auto input = InputGroup(argv[1],"input");

const int N = input.getInt("N"); //number of sites

const Real ttotal = input.getReal("T"); //total time to evolve
const Real cutoff = input.getReal("cutoff"); //truncation error cutoff
const Real tstep = input.getReal("dt");

h = input.getReal("hx");
hy = input.getReal("hy");
double Bz = input.getReal("hz");
const double h_z_over_m = input.getReal("hz_over_m");
const double misalignment = input.getReal("misalignment");
pairing = input.getReal("pairing");
yz = input.getReal("yz");
lam1 = input.getReal("NNNZZ");
lam2 = input.getReal("hopping");

const int Xmax = input.getInt("Xmax");
const double m_tol = input.getReal("m_tol");

int component = input.getInt("component");
// 1: xx
// 2: yy
// 3: zz

string measure;
switch(component){
	case 1: measure = "Sx"; break;
	case 2: measure = "Sy"; break;
	case 3: measure = "Sz"; break;
	default: throw runtime_error("Invalid component in input file");
}

int measure_every=input.getInt("measure_every");


///////////////////////////////////////////////////////////////////////////////
//	Compute ground state and determin hz self-consistently	///////////////////
///////////////////////////////////////////////////////////////////////////////
auto sites = SpinHalf(N);
auto state = InitState(sites);
for(int i = 1; i <= N; ++i) 
	state.set(i, "Up");

auto psi = MPS(state);

double m = 0; //m=1 for the |\up \up ... > state
double m_error = 2*m_tol;

while (abs(m_error)>m_tol){

	hz = -(abs(h) + abs(hy)) * misalignment - m * h_z_over_m - Bz;

	auto H = Hamiltonian(sites);

	auto sweeps = Sweeps(20);
	sweeps.maxm() = 50,50,100,200,500;
	sweeps.cutoff() = cutoff;
	sweeps.noise() = 1e-8, 1e-8, 1e-9, 1e-10, 1e-11;

	auto energy = dmrg(psi,H,sweeps,"Quiet");

	//printfln("\nGround State Energy = %.10f",energy);
	//printfln("\nUsing overlap = %.10f\n", overlap(psi,H,psi) );
	printfln("\nEnergy variance = %.10f\n", sqrt(overlap(psi,H,H,psi)-sqr(energy)));

	//measure magnetization
	double new_m = magnetization(psi);
	m_error = new_m - m;
	m = new_m;
	cout << "Magnetization = " << m << "\n";
	cout << "Magnetization error in the last step = " << abs(m_error) << "\n";
}

hz = -(abs(h) + abs(hy)) * misalignment - m * h_z_over_m - Bz;

auto H = Hamiltonian(sites);

auto sweeps = Sweeps(50);
sweeps.maxm() = 50,50,50,Xmax;
sweeps.cutoff() = cutoff;
sweeps.noise() = 1e-7, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11;

auto energy = dmrg(psi,H,sweeps,"Quiet");

//printfln("\nGround State Energy = %.10f",energy);
//printfln("\nUsing overlap = %.10f\n", overlap(psi,H,psi) );
printfln("\nEnergy variance = %.10f\n", sqrt(overlap(psi,H,H,psi)-sqr(energy)));

//measure magnetization
double new_m = magnetization(psi);
m_error = new_m - m;
m = new_m;

cout << "Magnetization = " << m << "\n";
cout << "Magnetization error in the last step = " << abs(m_error) << "\n";


ofstream f_energy;
f_energy.open("energy.txt");
f_energy << tstep <<"\n";
f_energy << energy <<"\n";
f_energy << m << "\n";
f_energy << hz << "\n";
f_energy.close();


///////////////////////////////////////////////
//	Setup time-evolution gates	///////////////
///////////////////////////////////////////////

auto gates = vector<Gate3>();

for(int b = 1; b < N-2; b++){
    gates.push_back(three_site_gate(b,sites,tstep/2));
}
gates.push_back(three_site_gate(N-2,sites,tstep));
for(int b = N-3; b > 0; b--){
    gates.push_back(three_site_gate(b,sites,tstep/2));
}


///////////////////////////////////////////////
//	Setup observables 	///////////////////////
///////////////////////////////////////////////



auto B=vector<SiteOp>(N);
for (int i=1; i<=N; i++){
	auto s = sites(i);
	auto op = sites.op(measure,i);
	B.at(i-1)=SiteOp(sites,i,op);
}

auto s = sites(N/2);
auto op = sites.op(measure,N/2);
auto A=SiteOp(sites,N/2,op);

auto Obs=DSF_Observer(B,psi);


///////////////////////////////////////////////
//	Apply operator on ground state ////////////
///////////////////////////////////////////////

auto psi0=psi;
applySiteOp(A,psi);
psi.position(1);


///////////////////////////////////////////////
//	Finally time-evolve	///////////////////////
///////////////////////////////////////////////

system("rm Tevol_out_*");


gateTEvol3(gates,ttotal,tstep,psi,Obs,{"Normalize",true,"Cutoff=",cutoff,"Verbose=",true,"Maxm",Xmax,"MeasureEvery",measure_every});
printfln("Maximum MPS bond dimension after time evolution is %d",maxM(psi));

/*auto im = complex<double>(0.,1.);
Print(overlapC(psi,psi0));
Print(overlapC(psi,psi0)*exp(-im*energy*ttotal)-1);*/




return 0;
}

/**
 * Bulk magnetization of psi: twice the average of the Sz measurement over the
 * sites N/4 .. 3N/4-1, where N = psi.N(). The factor 2 makes a fully
 * "Up"-polarized spin-1/2 chain give 1.
 *
 * Chains with fewer than 4 sites have no valid central window (N/4 would be
 * site 0, but sites are numbered from 1, and for N == 1 the window is empty),
 * so for those the average runs over all sites instead. An empty chain gives 0.
 *
 * Each single-site measurement contracts psi.A(i) with Sz and the primed
 * conjugate of psi.A(i) after psi.position(i).
 * NOTE(review): this equals <Sz_i> only if psi is normalized — confirm callers guarantee that.
 *
 * @param psi state to measure; modified, because psi.position(i) is called for every site visited
 * @return 2 * (average Sz over the window)
 */
double magnetization(MPS& psi){
	const auto& sites = psi.sites();
	const int N = psi.N();
	if (N < 1) return 0; // nothing to average over

	// Half-open window [first, last) of sites to average over.
	int first = N/4;
	int last = 3*N/4;
	if (N < 4){
		// Central window would start at site 0 or be empty: use the whole chain.
		first = 1;
		last = N+1;
	}

	double m=0;
	for (int i=first; i<last; i++){
		psi.position(i);
		const auto& A = psi.A(i);
		m += (A * toITensor(sites.op("Sz",i)) * dag(prime(A,Site))).real(); 
	}
	m /= (last-first);
	return 2*m;
}

