#include "libscl.h"
#include "svmod.h"

using namespace std;
using namespace scl;

void svmod::set_parms(const realmat& theta)
{
  if(theta.size() != 11) error("Error, svmod, wrong dimension for theta");
  a0 = theta[1];
  a1 = theta[2];
  b0 = theta[3];
  b11 = theta[4];
  b22 = theta[5];
  r11 = theta[6];
  r21 = theta[7];
  r22 = theta[8];
  r31 = theta[9];
  r32 = theta[10];
  r33 = theta[11];
  T(1,1) = r11;
  T(1,2) = 0.0;
  T(2,1) = r21;
  T(2,2) = r22;
  T = inv(T);
}

realmat svmod::get_parms() const
{
  // Pack the model parameters into an 11x1 column vector (1-based), in the
  // same order that set_parms() reads them back.
  realmat packed(11,1);

  packed[1]  = a0;    // observation intercept
  packed[2]  = a1;    // observation AR coefficient
  packed[3]  = b0;    // state intercept (shared by both factors)
  packed[4]  = b11;   // AR coefficient of state factor 1
  packed[5]  = b22;   // AR coefficient of state factor 2
  packed[6]  = r11;   // state shock loadings (lower triangle)
  packed[7]  = r21;
  packed[8]  = r22;
  packed[9]  = r31;   // observation-shock loadings
  packed[10] = r32;
  packed[11] = r33;

  return packed;
}

void svmod::draw_x0_y0(realmat& x, realmat& y, INT_32BIT& seed) const
{
  // Produce starting values (x, y) by running the model forward 50 periods
  // from an all-zero state; the final period's draws are returned through x
  // and y.  (Presumably intended as a burn-in to wash out the zero start.)
  realmat xprev(4,1,0.0);
  realmat yprev(2,1,0.0);
  realmat innov(2,1);
  y.resize(2,1);

  for (INTEGER remaining = 50; remaining > 0; --remaining) {
    // Advance the state; this consumes two draws from seed.
    x = draw_xt(xprev, seed);

    // Back out the standardized state shocks via T (inverse factor).
    innov[1] = x[1] - b0 - b11*x[3];
    innov[2] = x[2] - b0 - b22*x[4];
    innov = T*innov;

    // Observation shock: correlated with the state shocks plus one fresh draw.
    const REAL shock = r31*innov[1] + r32*innov[2] + r33*unsk(seed);

    // y[2] carries the previous y[1] so the next period can use it as a lag.
    y[2] = yprev[1];
    y[1] = a0 + a1*yprev[1] + exp(x[3]+x[4])*shock;

    xprev = x;
    yprev = y;
  }
}

realmat svmod::draw_xt(const realmat& xlag, INT_32BIT& seed) const
{
  // One-step transition of the 4x1 state vector:
  //   xt[1], xt[2] : the two factors at this period
  //   xt[3], xt[4] : copies of the previous period's factors (xlag[1], xlag[2])
  // Two draws are taken from seed, in this order (presumably N(0,1)).
  const REAL z1 = unsk(seed);
  const REAL z2 = unsk(seed);

  realmat xt(4,1);
  xt[3] = xlag[1];
  xt[4] = xlag[2];

  // AR(1) factors with lower-triangular shock loadings.
  xt[1] = b0 + b11*xt[3] + r11*z1;
  xt[2] = b0 + b22*xt[4] + r21*z1 + r22*z2;

  return xt;
}

REAL svmod::prob_yt(const realmat& y, const realmat& x) const
{
  // Gaussian density of the observable y[1] conditional on the state x and
  // the lagged observable y[2]:
  //   mean = a0 + a1*y[2] + exp(x[3]+x[4])*(r31*e1 + r32*e2)
  //   sd   = |exp(x[3]+x[4])*r33|
  // where (e1, e2) are the state shocks recovered through T, the inverse of
  // the lower-triangular factor [r11 0; r21 r22] built in set_parms.
  //
  // sqrt(2*pi); 2*pi = 6.283185307179586...
  const REAL roottwopi = sqrt(6.283185307179586);

  // Back out the standardized state shocks for this period.
  realmat e(2,1);
  e[1] = x[1] - b0 - b11*x[3];
  e[2] = x[2] - b0 - b22*x[4];
  e = T*e;

  // Volatility scale of the observation; computed once and reused.
  const REAL scale = exp(x[3] + x[4]);
  const REAL mu = a0 + a1*y[2] + scale*(r31*e[1] + r32*e[2]);

  // r33 is not sign-restricted, and the simulators draw r33*unsk(seed), whose
  // spread depends only on |r33| (assuming unsk is a symmetric N(0,1) draw).
  // Use |sd| so the density is non-negative when r33 < 0.
  REAL sd = scale*r33;
  if (sd < 0.0) sd = -sd;

  const REAL z = (y[1] - mu)/sd;
  return exp(-0.5*z*z)/(roottwopi*sd);
}

sample svmod::draw_sample(INTEGER n, INT_32BIT& seed) const
{
  // Simulate n periods of the model.  Starting values come from draw_x0_y0
  // and are stored in s.x0 / s.y0; column t of s.X (4 rows) and s.Y (2 rows)
  // holds the state and observable for period t.
  sample s(n);
  draw_x0_y0(s.x0, s.y0, seed);

  realmat xprev = s.x0;
  realmat yprev = s.y0;
  realmat xt;
  realmat yt(2,1);
  realmat innov(2,1);

  for (INTEGER col = 1; col <= n; ++col) {
    // Advance the state; this consumes two draws from seed.
    xt = draw_xt(xprev, seed);

    // Recover the standardized state shocks through T (inverse factor).
    innov[1] = xt[1] - b0 - b11*xt[3];
    innov[2] = xt[2] - b0 - b22*xt[4];
    innov = T*innov;

    // Observation shock: loads on the state shocks plus one fresh draw.
    const REAL shock = r31*innov[1] + r32*innov[2] + r33*unsk(seed);

    // yt[2] carries the previous yt[1] forward as the lag.
    yt[2] = yprev[1];
    yt[1] = a0 + a1*yprev[1] + exp(xt[3]+xt[4])*shock;

    // Record period `col`.
    for (INTEGER row = 1; row <= 4; ++row) s.X(row,col) = xt[row];
    s.Y(1,col) = yt[1];
    s.Y(2,col) = yt[2];

    xprev = xt;
    yprev = yt;
  }

  return s;
}
