/*-----------------------------------------------------------------------------

Copyright (C) 2005, 2006, 2007.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libsmm.h"
#include "emm.h"

using namespace scl;
using namespace libsmm;
using namespace emm;
using namespace std;

// Constructs the user model.
//
//   dat           data matrix, copied into the member `data`
//   len_mod_parm  parameter count from the parmfile; must equal lrho (11)
//   len_mod_func  functional count from the parmfile; must equal lstats (1)
//   mod_pfvec     not used by this constructor
//   mod_alvec     additional parmfile lines; the first 12 characters of the
//                 SECOND entry are parsed as the integer blen
//   detail        not used by this constructor
//
// Calls error() when mod_alvec has no second entry or when either length
// disagrees with the model; calls warn() and raises blen to 2*lrho+1 when
// the parsed value is smaller than that.
emm::elec_usrmod::elec_usrmod
  (const realmat& dat, INTEGER len_mod_parm, INTEGER len_mod_func,
   const std::vector<std::string>& mod_pfvec,
   const std::vector<std::string>& mod_alvec, 
   std::ostream& detail)
: data(dat), rho(), blen(23), lrho(11), lstats(1), 
  variable_seed(740726), a(), B(), R()
{
  // The first entry of mod_alvec is skipped; blen is read from the second.
  // Guard the access so a short parmfile section is reported instead of
  // reading past the end of the vector.  The read sits in the else branch
  // so that blen simply keeps its default should error() ever return.
  if (mod_alvec.size() < 2) {
    error("Error, usrmod, constructor, blen is missing from parmfile");
  }
  else {
    // atoi yields 0 on unparsable text, which the lower bound below corrects.
    blen = atoi(mod_alvec[1].substr(0,12).c_str());
  }

  if (lrho != len_mod_parm) {
    error("Error, usrmod, constructor, len_mod_parm is set wrong in parmfile");
  }

  if (lstats != len_mod_func) {
    error("Error, usrmod, constructor, len_mod_func is set wrong in parmfile");
  }

  // Enforce the lower bound 2*lrho+1 on blen.
  if (blen < 2*lrho+1) {
    blen = 2*lrho+1;
    warn("Warning: usrmod, constructor, blen increased to 2*lrho+1");
  }
}

// Unpacks the 11 entries of rho into the working quantities a, B and R.
//
//   rho[1..2]   -> a[1], a[2]; a[3] is fixed at -1
//   rho[3..8]   -> upper triangle of the 3x3 matrix B, mirrored below the
//                  diagonal so that B is symmetric
//   rho[9..11]  -> R(1,1), R(1,2), R(2,2); R(2,1) is fixed at 0, so R is
//                  upper triangular
void emm::elec_usrmod::set_parms()
{
  // 3x1 vector a.
  a.resize(3,1);
  a[1] = rho[1];
  a[2] = rho[2];
  a[3] = -1.0;

  // Symmetric 3x3 matrix B: each off-diagonal pair is set together.
  B.resize(3,3);
  B(1,1) = rho[3];
  B(2,1) = B(1,2) = rho[4];
  B(2,2) = rho[5];
  B(3,1) = B(1,3) = rho[6];
  B(3,2) = B(2,3) = rho[7];
  B(3,3) = rho[8];

  // Upper triangular 2x2 matrix R.
  R.resize(2,2);
  R(1,1) = rho[9];
  R(1,2) = rho[10];
  R(2,1) = 0.0;
  R(2,2) = rho[11];
}

// Evaluates the Gaussian log likelihood of the data at the current rho,
// using the members a, B and R.
//
// Rows 1:2 of data are the responses y, rows 3:5 the regressors x; each
// column is one observation.  With s = B*x + a (column by column):
//
//   yhat(i,t) = log(-s(i,t)) - log(-s(3,t)),   i = 1,2
//   zhat      = inv(R) * (y - yhat)
//
// yhat and zhat are output arguments, resized to 2 x n.  They are only
// filled when the returned den_val is positive.
//
// Returns den_val(false,-REAL_MAX) when rho is outside the support or when
// any element of s is not strictly negative (the logs would be undefined);
// otherwise den_val(true, log likelihood).  Calls error() when data does
// not have 5 rows or rho does not have 11 rows.
den_val emm::elec_usrmod::likelihood(realmat& yhat, realmat& zhat) 
{
  const INTEGER r = data.get_rows();
  const INTEGER n = data.get_cols();

  // Validate shapes first: support() indexes rho[9] and rho[11], so the
  // length of rho has to be known to be right before it is called.
  if (r != 5) error("Error, elec_usrmod, likelihood, bad data");
  if (rho.get_rows() != 11) error("Error, elec_usrmod, likelihood, bad parm");

  if (!support(rho)) return den_val(false,-REAL_MAX);

  realmat y = data("1:2","");
  realmat x = data("3:5","");

  realmat s = B*x;
  for (INTEGER t=1; t<=n; ++t) {
    s(1,t) += a[1];
    s(2,t) += a[2];
    s(3,t) += a[3];
  }

  // Every element of s must be strictly negative.  One pass over the
  // storage covers all columns.
  const INTEGER slen = s.size();
  for (INTEGER i=1; i<=slen; ++i) {
    if (s[i] >= 0.0) return den_val(false,-REAL_MAX);
  }

  yhat.resize(2,n);
  zhat.resize(2,n);

  realmat ehat(2,n);

  for (INTEGER t=1; t<=n; ++t) {
    const REAL bot = log(-s(3,t));
    yhat(1,t) = log(-s(1,t)) - bot;
    yhat(2,t) = log(-s(2,t)) - bot;
    ehat(1,t) = y(1,t) - yhat(1,t);
    ehat(2,t) = y(2,t) - yhat(2,t);
  }

  const realmat P = inv(R);

  zhat = P*ehat;

  // Quadratic form: -0.5 * sum of squared standardized residuals.
  REAL q = 0.0;

  for (INTEGER t=1; t<=n; ++t) {
    q += zhat(1,t)*zhat(1,t) + zhat(2,t)*zhat(2,t);
  }

  q *= (-0.5);

  // R is upper triangular, so its determinant is the product of the
  // diagonal; the Jacobian of e = R*z contributes -log(det R) per column.
  const REAL detR = R(1,1)*R(2,2);

  q -= REAL(n)*log(detR);

  const REAL pi = 3.14159265358979312e+00;

  // Each observation is bivariate, so the normalizing constant is
  // 2*log(sqrt(2*pi)) = log(2*pi) per observation, not log(sqrt(2*pi)).
  q -= REAL(n)*log(2.0*pi);

  return den_val(true,q);
}

bool emm::elec_usrmod::support(const realmat& parm) 
{
  if (parm[9] <= 0.0) return false;
  if (parm[11] <= 0.0) return false;

  return true; 
}

// Prior for this model.  Neither argument is consulted: the result is
// always den_val(true, 0.0), whatever rho_in and stats contain.
den_val emm::elec_usrmod::prior(const realmat& rho_in, const realmat& stats) 
{
  den_val flat(true, 0.0);
  return flat;
}

// Fills bs with 2*lrho+1 simulated 5 x n data sets generated at the
// current rho.
//
// Each data set keeps the regressors (rows 3:5 of data) unchanged and
// replaces the responses (rows 1:2) by
//
//   yhat(.,t) + R * zhat(.,u)
//
// where yhat and zhat come from likelihood() and u is a column index drawn
// afresh for every t from iran(variable_seed, n-1) plus one.
// NOTE(review): presumably u is uniform on 1..n, i.e. residuals are
// resampled with replacement -- confirm against the libscl iran contract.
// NOTE(review): the number of data sets is 2*lrho+1, not blen -- confirm
// that this is intended.
//
// bs is resized to 2*lrho+1 when it has a different length.  variable_seed
// is advanced by every draw.
//
// Returns false, leaving bs untouched, when rho is outside the support or
// likelihood() does not return a positive den_val; returns true otherwise.
bool emm::elec_usrmod::gen_bootstrap(vector<realmat>& bs)
{
  if (!support(rho)) return false;

  realmat yhat, zhat;

  const den_val dv = likelihood(yhat, zhat);

  if (!dv.positive) return false;

  const INTEGER len = 2*lrho+1;
  const INTEGER len_vec = bs.size();

  if (len_vec != len) {
     bs.resize(len);
  }

  const INTEGER n = data.get_cols();

  // The regressor rows are the same in every replication, so copy them
  // into the template once; only rows 1:2 are rewritten below.
  realmat x = data("3:5","");
  realmat sim(5,n);

  for (INTEGER t=1; t<=n; ++t) {
    sim(3,t) = x(1,t);
    sim(4,t) = x(2,t);
    sim(5,t) = x(3,t);
  }

  realmat z, e;

  for (INTEGER i=0; i<len; ++i) {
    for (INTEGER t=1; t<=n; ++t) {
      INTEGER u = iran(variable_seed, n-1);
      ++u;
      z = zhat("",u);
      e = R*z;
      // yhat(.,t) is log(-s(.,t)) - log(-s(3,t)) with s = B*x + a, as
      // computed by likelihood().  A positive dv guarantees every s was
      // strictly negative, so neither the sign check nor the logs need
      // repeating for each replication.
      sim(1,t) = yhat(1,t) + e[1];
      sim(2,t) = yhat(2,t) + e[2];
    }
    bs[i] = sim;
  }

  return true;
}

