#include "logistic.h"
#include <cmath>
       
using namespace scl;
using namespace std; 

// Short alias for logistic_neg_log_likelihood, used below to qualify the
// member-function definitions of that class.
typedef logistic_neg_log_likelihood lnll;

// Constructs the objective by loading the data set "logitsimC.dat" from the
// current working directory.
//
// Exactly n = 500 rows are read; each row holds the response y followed by the
// regressors x1, x2 and x3. p = 4 is the number of coefficients: an intercept
// plus one per regressor, matching the linear index used in get_F. The data
// vectors are n x 1 realmats indexed from 1.
lnll::logistic_neg_log_likelihood()
{
  ifstream fin("logitsimC.dat");
  if (!fin) error("Error, cannot open logitsimC.dat");
  n = 500;
  p = 4;
  y.resize(n,1);
  x1.resize(n,1);
  x2.resize(n,1);
  x3.resize(n,1);
  for (INTEGER t=1; t<=n; ++t) {
    fin >> y[t] >> x1[t] >> x2[t] >> x3[t];
    // Without this check a short or malformed file would silently leave the
    // remaining observations unread and the objective would use them anyway.
    if (!fin) error("Error, cannot read 500 observations from logitsimC.dat");
  }
}

bool lnll::get_F(const realmat& b, realmat& f, realmat& F)
{
  if (b.size() != p) error("Error, b wrong size");
  f.resize(1,1,0.0); 
  F.resize(1,p,0.0);
  for (INTEGER t=1; t<=n; ++t) {
    REAL z = b[1] + b[2]*x1[t] + b[3]*x2[t] + b[4]*x3[t];
    REAL phat = exp(z)/(1.0 + exp(z));
    f[1] -= y[t]*log(phat) + (1.0 - y[t])*log(1.0 - phat);
    REAL d_f_wrt_phat = -y[t]/phat + (1.0 - y[t])/(1.0 - phat);     
    REAL d_phat_wrt_z = exp(z)/(1.0+exp(z)) - pow(exp(z)/(1.0+exp(z)),2);
    F[1] += d_f_wrt_phat*d_phat_wrt_z;
    F[2] += d_f_wrt_phat*d_phat_wrt_z*x1[t];
    F[3] += d_f_wrt_phat*d_phat_wrt_z*x2[t];
    F[4] += d_f_wrt_phat*d_phat_wrt_z*x3[t];
  }
  return true;
}

// Objective value at b together with a gradient produced by nleqns_base::df
// (presumably a numerical/finite-difference derivative — confirm in the base
// class) rather than the analytic gradient of get_F.
// Returns false if get_f fails, in which case df is never invoked; otherwise
// returns the result of df.
bool lnll::get_numerical_F(const realmat& b, realmat& f, realmat& F)
{
  return get_f(b, f) && nleqns_base::df(b, F);
}

bool lnll::get_f(const realmat& b, realmat& f)
{
  realmat F; 
  return get_F(b, f, F);
}

// Average outer product of the per-observation gradients of the negative
// log-likelihood, evaluated at b:
//   I = (1/n) * sum_t s_t * s_t',   s_t = (phat_t - y_t) * (1, x1_t, x2_t, x3_t)'
// b : p x 1 coefficient vector; any other size is rejected with error(),
//     matching get_F.
// Returns the p x p matrix I.
//
// The per-observation gradient is written as (phat - y) times the regressors,
// which is algebraically the same as d_f/d_phat * d_phat/d_z but stays finite
// for large |z|. phat is built from e = exp(-|z|), which cannot overflow.
realmat lnll::get_information_matrix(const realmat& b)
{
  if (b.size() != p) error("Error, b wrong size");
  realmat F(p,1);
  realmat I(p,p,0.0);
  for (INTEGER t=1; t<=n; ++t) {
    const REAL z = b[1] + b[2]*x1[t] + b[3]*x2[t] + b[4]*x3[t];
    const REAL e = exp(-fabs(z));                      // never overflows
    const REAL phat = (z >= 0.0) ? 1.0/(1.0 + e) : e/(1.0 + e);
    const REAL score = phat - y[t];
    F[1] = score;
    F[2] = score*x1[t];
    F[3] = score*x2[t];
    F[4] = score*x3[t];
    I += F*T(F);
  }
  I = I/n;
  return I;
}

// Number of coefficients in b (the member p).
INTEGER lnll::get_p()
{
  return p;
}

// Number of observations loaded from the data file (the member n).
INTEGER lnll::get_n()
{
  return n;
}
