/*-----------------------------------------------------------------------------

Copyright (C) 2004, 2005, 2006, 2007.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/
#undef USR_OBJFUN_TYPE_IMPLEMENTED
#undef USR_PROPOSAL_TYPE_IMPLEMENTED

#include "libscl.h"
#include "libsnp.h"
#include "libsmm.h"
#include "snp.h"
#include "emm.h"

using namespace scl;
using namespace libsnp;
using namespace libsmm;
using namespace snp;
using namespace emm;
using namespace std;

namespace {
  // Forward declaration of the per-chain output routine used by main();
  // its definition is in the second anonymous namespace at the end of
  // this file.
  void output(const estblock& est_blk, ostream& detail, INTEGER ifile,
              string prefix,
              const realmat& rho_sim, const realmat& stats_sim,
              const realmat& pi_sim, realmat reject,
              const realmat& rho_hat, const realmat& V_hat, INTEGER n,
              const realmat& rho_mean, const realmat& rho_mode,
              REAL post_high, const realmat& foc_hat,
              const realmat& I, const realmat& invJ, INTEGER reps);
}

int main(int argc, char** argp, char** envp)
{

  istream* ctrl_ptr;
  if (argc == 2) {
    ctrl_ptr = new(nothrow) ifstream(argp[1]);
    string msg = "Error, emm, " + string(argp[1]) + " open failed";
    if( (ctrl_ptr == 0) || (!*ctrl_ptr) ) error(msg);
  }
  else {
    ctrl_ptr = new(nothrow) ifstream("control.dat");
    string msg = "Error, emm, control.dat open failed";
    if( (ctrl_ptr == 0) || (!*ctrl_ptr) ) error(msg);
  }
  istream& ctrl = *ctrl_ptr;

  string parmfile;
  string prefix;
 
  while(ctrl >> parmfile >> prefix) {

    string filename = prefix + ".detail.dat";

    ofstream detail_ofs(filename.c_str());
    if (!detail_ofs) error("Error, emm, detail.dat open failed");
    ostream& detail = detail_ofs;

    string pfline;

    vector<string> pfvec;
    ifstream pf_ifs(parmfile.c_str());
    if (!pf_ifs) error("Error, emm, cannot open parmfile " + parmfile);
    while (getline(pf_ifs, pfline)) pfvec.push_back(pfline);

    emmparms pf;

    if(!pf.set_parms(pfvec,detail)) {
       detail.flush();
       error("Error, emm, cannot read EMM parmfile");
    }

    estblock est_blk = pf.get_estblock();
    datblock dat_blk = pf.get_datblock();
    modblock mod_blk = pf.get_modblock();
    objblock obj_blk = pf.get_objblock();

    realmat data; 
    if (!dat_blk.read_data(data)) {
      error("Error, emmparm, cannot read data, dsn = " + dat_blk.dsn);
    }

    if (est_blk.print) {
      detail << starbox("/First 12 observations//");
      detail << data("",seq(1,12));
      detail << starbox("/Last 12 observations//");
      detail << data("",seq(dat_blk.n-11, dat_blk.n));
      detail.flush();
    }

    vector<string> mod_pfvec;
    if (mod_blk.is_mod_parmfile) {
      ifstream mod_pf_ifs(mod_blk.mod_parmfile.c_str());
      if (!mod_pf_ifs) error("Error, emm, cannot open "+mod_blk.mod_parmfile);
      while (getline(mod_pf_ifs, pfline)) mod_pfvec.push_back(pfline);
    }

    usrmod_type usrmod 
      (data, mod_blk.len_mod_parm, mod_blk.len_mod_func, 
       mod_pfvec, pf.get_mod_alvec(), detail);

    proposal_base* proposal_ptr;

    if (est_blk.proptype == 0) {
      if (est_blk.is_inc_block) {
        proposal_ptr = new(nothrow) grid_group_move(pf.get_prop_groups());
	if (est_blk.print) {
	  grid_group_move(pf.get_prop_groups()).write_proposal(detail);
        }
      }
      else {
        proposal_ptr = new(nothrow) group_move(pf.get_prop_groups());
      }
      if (proposal_ptr == 0) error("Error, emm main, operator new failed.");
    }
    else if (est_blk.proptype == 1) {
      bool print = est_blk.print;
      proposal_ptr 
        = new(nothrow) conditional_move(pf.get_prop_groups(),detail,print);
      if (proposal_ptr == 0) error("Error, emm main, operator new failed.");
    } else {
      #if defined USR_PROPOSAL_TYPE_IMPLEMENTED
        proposal_ptr = new(nothrow) proposal_type(pf.get_prop_groups());
        if (proposal_ptr == 0) error("Error, emm main, operator new failed.");
      #else
        error("Error, emm, no user proposal type implemented");
        proposal_ptr = new(nothrow) group_move(pf.get_prop_groups());
        if (proposal_ptr == 0) error("Error, emm main, operator new failed.");
      #endif
    }

    proposal_base& proposal = *proposal_ptr;

    vector<string> obj_pfvec;
    if (obj_blk.is_obj_parmfile) {
      ifstream obj_pf_ifs(obj_blk.obj_parmfile.c_str());
      if (!obj_pf_ifs) error("Error, emm, cannot open "+obj_blk.obj_parmfile);
      while (getline(obj_pf_ifs, pfline)) obj_pfvec.push_back(pfline);
    }

    objfun_base* objfun_ptr;

    if (est_blk.objtype == 0) {

      if (est_blk.print) {
        detail << starbox("/SNP now controls print//");
        detail.flush();
      }

      objfun_ptr = new(nothrow) 
        emm_objfun
          (data, mod_blk.len_mod_parm, mod_blk.len_mod_func, 
           mod_pfvec, pf.get_mod_alvec(),
           obj_pfvec, pf.get_obj_alvec(), detail);
      if (objfun_ptr == 0) error("Error, emm main, operator new failed.");

      if (est_blk.print) {
        detail << starbox("/EMM now controls print//");
        detail.flush();
      }
    }
    else if (est_blk.objtype == 1) {

      if (est_blk.print) {
        detail << starbox("/The mle_objfun constructor now controls print//");
        detail << '\n'
	 << "\tThe mle_objfun constructor instantiates its own copy\n"
	 << "\tof usrmod.  The following output, if any, is from that\n" 
	 << "\tinstantiation.\n";
        detail.flush();
      }

      objfun_ptr = new(nothrow) 
        mle_objfun
          (data, mod_blk.len_mod_parm, mod_blk.len_mod_func, 
           mod_pfvec, pf.get_mod_alvec(),
           obj_pfvec, pf.get_obj_alvec(), detail);
      if (objfun_ptr == 0) error("Error, emm main, operator new failed.");

      if (est_blk.print) {
        detail << starbox("/EMM now controls print//");
        detail.flush();
      }
    }
    else {
      #if defined USR_OBJFUN_TYPE_IMPLEMENTED
        objfun_ptr = new(nothrow) objfun_type
          (data, mod_blk.len_mod_parm, mod_blk.len_mod_func, 
           mod_pfvec, pf.get_mod_alvec(),
           obj_pfvec, pf.get_obj_alvec(), detail);
        if (objfun_ptr == 0) error("Error, emm main, operator new failed.");
      #else
        error("Error, emm, no user objfun type implemented");
        objfun_ptr = new(nothrow) mle_objfun
          (data, mod_blk.len_mod_parm, mod_blk.len_mod_func, 
           mod_pfvec, pf.get_mod_alvec(),
           obj_pfvec, pf.get_obj_alvec(), detail);
        if (objfun_ptr == 0) error("Error, emm main, operator new failed.");
      #endif
    }

    objfun_base& objfun = *objfun_ptr;

    INTEGER cache_size = est_blk.max_cache_size;

    cachemgr cachemgr(usrmod, objfun, cache_size);

    filename = prefix + ".emmcache.dat";
    if (!cachemgr.read_cache(filename.c_str())) { 
      if (cache_size > 0) {
        warn("Warning, emm, emmcache read failed, using default instead");
      }
    }

    smm_mcmc mcmc(proposal, cachemgr, usrmod);
    mcmc.set_simulation_size(est_blk.lchain);
    mcmc.set_stride(est_blk.stride);
    mcmc.set_draw_from_posterior(!est_blk.draw_from_prior);
    mcmc.set_temp(est_blk.temperature);

    asymptotics_base* asymptotics_ptr;

    if (est_blk.kilse) {
      asymptotics_ptr 
        = new(nothrow) minimal_asymptotics(data,usrmod,mcmc);
      if (asymptotics_ptr == 0) error("Error, emm main, operator new failed.");
    }
    else {
      asymptotics_ptr 
        = new(nothrow) bstrap_asymptotics(data, usrmod, objfun, mcmc);
      if (asymptotics_ptr == 0) error("Error, emm main, operator new failed.");
    }

    asymptotics_base& asymptotics = *asymptotics_ptr;

    INT_32BIT seed = est_blk.seed;
    realmat rho = pf.get_rho();

    realmat rho_sim;
    realmat stats_sim;
    realmat pi_sim;
    realmat reject;

    for (INTEGER ifile = 0; ifile <= est_blk.nfile; ++ifile) {

      reject = mcmc.draw(seed, rho, rho_sim, stats_sim, pi_sim);

      filename = prefix + ".emmcache.new";
      if (!cachemgr.write_cache(filename.c_str())) {
        warn("Warning, emm, emmcache write failed");
      }

      realmat rho_hat;
      realmat V_hat;
      INTEGER n;
  
      realmat rho_mean;
      realmat rho_mode;
      REAL post_high;
      realmat foc_hat;
      realmat I;
      realmat invJ;
      INTEGER reps = 0;
  
      asymptotics.set_asymptotics(rho_sim);
      asymptotics.get_asymptotics(rho_hat,V_hat,n);
      asymptotics.get_asymptotics(rho_mean,rho_mode,post_high, 
                                    I,invJ,foc_hat,reps);
    
      usrmod.set_rho(rho_mode);
      realmat usr_sim, usr_stats;
      usrmod.gen_sim(usr_sim, usr_stats);
      objfun(rho_mode,usr_sim,usr_stats);
      filename = prefix + ".diagnostics.dat";
      objfun.write_diagnostics(filename.c_str());
      
      pf.write_parms(parmfile, prefix, seed, rho, rho_mode, invJ/n);
      usrmod.set_rho(rho_mode);
      filename = prefix + ".usrvar.dat";
      usrmod.write_usrvar(filename.c_str());
  
      output(est_blk, detail, ifile, prefix,
        rho_sim, stats_sim, pi_sim, reject, rho_hat, V_hat, n, 
        rho_mean, rho_mode, post_high, foc_hat, I, invJ, reps);

      if (est_blk.print && est_blk.max_cache_size > 0) {
        detail << '\n';
        detail << "\t Cache hit rate = " << cachemgr.cache_hit_rate() << '\n';
        detail.flush();
      }
    }
    delete objfun_ptr;
    delete proposal_ptr;
    delete asymptotics_ptr;
  }
  delete ctrl_ptr;
  return 0;
}

namespace {

  /*
    Write the results of one MCMC chain to files named from prefix.

    Always writes prefix.rho_mean.dat, prefix.rho_mode.dat,
    prefix.V_hat_hess.dat (invJ/n) and prefix.summary.dat.  Unless
    est_blk.kilse, it also writes prefix.V_hat_sand.dat (V_hat) and
    prefix.V_hat_info.dat (inv(I)/n).  In that case, when est_blk.print
    is set, it also prints the asymptotics to detail.  Finally it writes
    the chain itself to prefix.{rho,stats,pi,reject}.NNN.dat, where NNN
    is ifile as three zero-padded digits.

    ifile must be at most 999 so that it fits that three-digit field;
    otherwise error() is called.
  */
  void output(const estblock& est_blk,ostream& detail,INTEGER ifile,
      string prefix, const realmat& rho_sim, const realmat& stats_sim,
      const realmat& pi_sim, realmat reject, const realmat& rho_hat,
      const realmat& V_hat, INTEGER n, const realmat& rho_mean,
      const realmat& rho_mode, REAL post_high, const realmat& foc_hat,
      const realmat& I, const realmat& invJ, INTEGER reps)
  {
    if (ifile > 999) error("Error, emm, output, nfile too big");

    string filename;

    /*
    filename = prefix + ".rho_hat.dat";
    vecwrite(filename.c_str(), rho_hat);
    */

    filename = prefix + ".rho_mean.dat";
    vecwrite(filename.c_str(), rho_mean);

    filename = prefix + ".rho_mode.dat";
    vecwrite(filename.c_str(), rho_mode);

    realmat V_hat_hess = invJ/n;

    filename = prefix + ".V_hat_hess.dat";
    vecwrite(filename.c_str(), V_hat_hess);

    realmat V_hat_info;

    if (!est_blk.kilse) {

      filename = prefix + ".V_hat_sand.dat";
      vecwrite(filename.c_str(), V_hat);

      V_hat_info = inv(I)/n;

      filename = prefix + ".V_hat_info.dat";
      vecwrite(filename.c_str(), V_hat_info);

      if (est_blk.print) {
        detail << starbox("/Get asymptotics/results are cumulative//") << '\n';
        detail << "\t ifile = " << ifile << '\n';
        detail << '\n';
        detail << "\t rho_hat = " << rho_hat << '\n';
        detail << "\t V_hat = " << V_hat << '\n';
        detail << "\t n = " << n << '\n';
        detail << "\t rho_mean = " << rho_mean << '\n';
        detail << "\t rho_mode = " << rho_mode << '\n';
        detail << "\t post_high = " << post_high << '\n';
        detail << "\t I = " << I << '\n';
        detail << "\t invJ = " << invJ << '\n';
        detail << "\t foc_hat = " << foc_hat << '\n';
        detail << "\t reps = " << reps << '\n';
        detail.flush();
      }

    }

    filename = prefix + ".summary.dat";
    ofstream summary_ofs(filename.c_str());
    if (summary_ofs) {
      summary_ofs
        << "   parm"
        << "     rhomean"
        << "     rhomode"
        << "      sesand"
        << "      sehess"
        << "      seinfo"
        << '\n';
      // The rows index rho_mean and rho_mode element-wise, so bound the
      // loop by rho_mode.size() (the same parameter count used for the
      // seinfo degrees of freedom below) rather than by rho_hat.size().
      const INTEGER lrho = rho_mode.size();
      for (INTEGER i=1; i<=lrho; ++i) {
        if (est_blk.kilse) {
          summary_ofs
            << fmt('i',7,i)
            << fmt('g',12,5,rho_mean[i])
            << fmt('g',12,5,rho_mode[i])
            << "            "
            << fmt('g',12,5,sqrt(V_hat_hess(i,i)))
            << "            "
            << '\n';
        }
        else {
          summary_ofs
            << fmt('i',7,i)
            << fmt('g',12,5,rho_mean[i])
            << fmt('g',12,5,rho_mode[i])
            << fmt('g',12,5,sqrt(V_hat(i,i)))
            << fmt('g',12,5,sqrt(V_hat_hess(i,i)))
            << fmt('g',12,5,sqrt(V_hat_info(i,i)))
            << '\n';
        }
      }
      summary_ofs << '\n';
      summary_ofs << "The log posterior (log prior - objfun) at the mode ";
      summary_ofs << "is" << fmt('g',12,5,post_high) << ".\n";
      if (est_blk.objtype == 0) {
        summary_ofs << '\n';
        summary_ofs << "For EMM, if usrmod.prior returns only 0 or 1, ";
        summary_ofs << "objfun is a chi-square \non ltheta - 1 - lrho ";
        summary_ofs << "degrees of freedom.  See SNP parmfile for ltheta.\n";
      }
      if (!est_blk.kilse) {
        summary_ofs << '\n';
        summary_ofs << "The degrees of freedom for seinfo are ";
        summary_ofs << reps - rho_mode.size() << ".\n";
      }
    }
    else {
      // Best effort as before, but no longer silent.
      warn("Warning, emm, summary.dat open failed");
    }

    // %d requires an int; INTEGER is a project typedef, so cast
    // explicitly.  The buffer comfortably holds any three-digit value.
    char number[16];
    sprintf(number,"%03d",static_cast<int>(ifile));

    filename = prefix + ".rho." + number + ".dat";
    vecwrite(filename.c_str(), rho_sim);

    filename = prefix + ".stats." + number + ".dat";
    vecwrite(filename.c_str(), stats_sim);

    filename = prefix + ".pi." + number + ".dat";
    vecwrite(filename.c_str(), pi_sim);

    filename = prefix + ".reject." + number + ".dat";
    vecwrite(filename.c_str(), reject);
  }
}

