#!/usr/bin/env python

"""
This script uses WHAM to combine multiple umbrella sampling windows into one histogram
"""

import numpy as np
import copy
import sys

def load_last_hist_file(filename,op_dim,col_id):
	"""Read one histogram file into a dict keyed by order-parameter tuple.

	Every data line is split on whitespace.  The first `op_dim` columns are
	converted to floats and form the (tuple) key; column `col_id` is
	converted to float and stored as the value.  Blank lines and lines whose
	first non-blank character is '#' are skipped.  If the same key occurs
	more than once, the last occurrence wins.

	Raises ValueError if a used field is not a number and IndexError if
	`col_id` does not exist on a data line.
	"""
	ops = {}
	# 'with' closes the handle even when a line fails to parse.
	with open(filename,'r') as fin:
		for line in fin:
			vals = line.split()
			# An empty token list is a blank line; indexing into it used to
			# raise IndexError, so it is skipped together with '#' comments.
			if not vals or vals[0][0] == '#':
				continue
			ops[tuple([float(xx) for xx in vals[:op_dim]])] = float(vals[col_id])
	return ops

def load_weight_file(filename,op_dim):
	"""Read one weight file into a dict keyed by order-parameter tuple.

	Every data line is split on whitespace.  The first `op_dim` columns are
	converted to floats and form the (tuple) key; the column right after
	them (index `op_dim`) is converted to float and stored as the weight.

	Lines are skipped when they are blank, start with '#', or have too few
	columns to contain a weight.  Raises ValueError if a used field is not
	a number.
	"""
	weights = {}
	# 'with' closes the handle even when a line fails to parse.
	with open(filename,'r') as fin:
		for line in fin:
			vals = line.split()
			# Keep the historical "at least two columns" rule, but also make
			# sure vals[op_dim] exists when op_dim > 1.
			if len(vals) < 2 or len(vals) <= op_dim:
				continue
			# Same comment convention as the histogram files.
			if vals[0][0] == '#':
				continue
			weights[tuple([float(xx) for xx in vals[:op_dim]])] = float(vals[op_dim])

	return weights

def wham_process_um2(weight_files, last_hist_files,op_dim=1,column_id=-1,max_iter=10000,tol=1e-6,info_file=None):
	"""Combine several umbrella-sampling windows into one normalised histogram.

	Parameters:
	  weight_files    -- list of weight files, one per window.  Warning: they
	                     need to be different files; otherwise the histogram
	                     files need to be merged together beforehand.
	  last_hist_files -- list of histogram files, indexed in parallel with
	                     `weight_files`.
	  op_dim          -- number of leading columns that form the order
	                     parameter (the dictionary key).
	  column_id       -- histogram column to read; -1 (default) selects
	                     column `op_dim`, i.e. the first column after the
	                     order parameter.
	  max_iter        -- upper bound on the number of self-consistent
	                     iterations.
	  tol             -- iteration stops once no per-window normalisation
	                     factor changes by more than this between two steps.
	  info_file       -- optional writable file-like object that receives the
	                     iteration summary.  When omitted, a module-level
	                     `info_file` is used if one exists; otherwise nothing
	                     is written.

	Returns a dict mapping each order-parameter tuple of the FIRST window's
	histogram to its combined probability (values are normalised to sum
	to 1).

	Raises IndexError if `weight_files` is empty or `last_hist_files` is
	shorter, KeyError if a bin is missing from a weight file or from a later
	window's histogram, and ZeroDivisionError if a window's weighted
	histogram sums to zero.
	"""
	#column id is the unbiased histogram column from simulation
	if column_id == -1:
		column_id = op_dim

	# Fall back to the script-level log file, if there is one, so existing
	# callers that rely on that global keep working.
	if info_file is None:
		info_file = globals().get('info_file')

	windows = len(weight_files)
	weights = []
	last_hists = []
	counters = []

	for i, weight_file in enumerate(weight_files):
		new_weight = load_weight_file(weight_file,op_dim)
		new_hist = load_last_hist_file(last_hist_files[i],op_dim,column_id)

		# Multiply every bin by its weight, then normalise the window.
		for mykey in new_hist:
			new_hist[mykey] *= new_weight[mykey]

		counter = float(sum(new_hist.values()))
		if counter == 0 and new_hist:
			# Same exception type the bare division would raise, but naming
			# the offending input.
			raise ZeroDivisionError('weighted histogram of window %d (%s, %s) sums to zero' % (i, weight_file, last_hist_files[i]))
		for mykey in new_hist:
			new_hist[mykey] /= counter

		weights.append(new_weight)
		last_hists.append(new_hist)
		counters.append(counter)

	# Only the bins of the first window are iterated over.
	keys = last_hists[0].keys()

	# One normalisation factor per window, refined until self-consistent.
	expminuslogf = np.ones(windows)
	expminuslogfnew = expminuslogf.copy()
	rho = {}

	n_iter = 0
	while ( n_iter < max_iter and (n_iter == 0 or max(abs(expminuslogf - expminuslogfnew)) > tol) ):
		n_iter += 1

		expminuslogf = expminuslogfnew.copy()

		for op in keys:
			num = 0
			denum = 0
			for i in range(windows):
				counter = counters[i]
				# normalised histogram * counter == weighted counts of window i
				num += last_hists[i][op] * counter
				denum += weights[i][op] * counter / expminuslogf[i]
			if denum == 0:
				rho[op] = 0.
			else:
				rho[op] = float(num)/denum

		total = sum(rho.values())

		for key in rho:
			rho[key] /= float(total)

		# Re-estimate every window's factor from the current combined estimate.
		expminuslogfnew = np.zeros(windows)
		for i in range(windows):
			for op in keys:
				expminuslogfnew[i] += weights[i][op] * rho[op]

	if info_file is not None:
		info_file.write('#Converged after %d iterations\n' % (n_iter) )
		# The loop also ends when the iteration cap is reached; say so
		# instead of silently reporting convergence.
		if n_iter > 0 and not max(abs(expminuslogf - expminuslogfnew)) <= tol:
			info_file.write('#Warning: tolerance %g not reached after %d iterations\n' % (tol, n_iter) )

	return rho


# Command-line driver:
#   script order_param_dimension weight_file hist_file [weight_file hist_file ...]
# Writes a short log to wham_information.txt and prints the combined
# histogram to stdout, one bin per line, sorted by order parameter.
# The __main__ guard keeps the functions above importable without running it.
if __name__ == '__main__':
	if len(sys.argv ) < 4 or len(sys.argv) % 2 != 0 :
		sys.stdout.write('Usage %s order_param_dimension weight_file hist_file weight_file hist_file .... ' % (sys.argv[0]) + '\n')
		sys.exit(1)

	op_dim = int(sys.argv[1])

	# After the dimension, arguments alternate: weight file, histogram file.
	weight_files = sys.argv[2::2]
	hist_files = sys.argv[3::2]

	# info_file stays a module-level name (wham_process_um2 writes to it);
	# 'with' closes it even if the processing below raises.
	with open("wham_information.txt", "w") as info_file:
		info_file.write('#Loaded weight files: %s\n' % weight_files)
		info_file.write('#Loaded hist files: %s\n' % hist_files)
		info_file.write('#Taking op-dim: %s\n' %op_dim )

		result = wham_process_um2(weight_files, hist_files,op_dim)

		for key in sorted(result):
			# Order-parameter components first, then the probability,
			# separated by single spaces.
			fields = [str(key[i]) for i in range(op_dim)]
			fields.append(str(result[key]))
			sys.stdout.write(' '.join(fields) + '\n')
