#include "libscl.h"
using namespace std;
using namespace scl;

#undef USE_POW

class nlregr : public nleqns_base {
private:
  INTEGER n;
  realmat theta;
  realmat y;
  realmat x;
public:
nlregr() : n(250000),theta(3,1,1.0),y(n,1),x(n,1)
{
  INT_32BIT seed = 740726;
  for (INTEGER i=1; i<=n; ++i) {
    x[i] = ran(seed);
    y[i] = theta[1] + theta[2]*exp(x[i]*theta[3]) + unsk(seed);
  }
} 
bool get_f(const realmat& theta, realmat& mse) 
{
  if (theta.get_rows()!=3) error("Error, nlregr, wrong dim for theta");
  if (mse.get_rows()!=1) mse.resize(1,1);
  REAL sum = 0.0;
  for (INTEGER i=1; i<=n; ++i) {
    REAL res = y[i] - theta[1] - theta[2]*exp(x[i]*theta[3]);
    #if defined USE_POW
      sum += pow(res,2);
    #else
      sum += res*res;
    #endif
  }
  mse[1] = sum/REAL(n);
  return true;
}
bool get_F(const realmat& theta, realmat& mse, realmat& dmse)
{
  if (theta.get_rows()!=3) error("Error, nlregr, wrong dim for theta");
  if (mse.get_rows()!=1) mse.resize(1,1);
  if (dmse.get_rows()!=1 || dmse.get_cols()!=3) dmse.resize(1,3);
  REAL sum = 0.0;
  realmat dsum(1,3,0.0);
  for (INTEGER i=1; i<=n; ++i) {
    REAL res = y[i] - theta[1] - theta[2]*exp(x[i]*theta[3]);
    #if defined USE_POW
      sum += pow(res,2);
    #else
      sum += res*res;
    #endif
    dsum[1] -= 2.0*res;
    dsum[2] -= 2.0*res*exp(x[i]*theta[3]);
    dsum[3] -= 2.0*res*theta[2]*exp(x[i]*theta[3])*x[i];
  }
  mse[1] = sum/REAL(n);
  dmse = dsum/n;
  return true;
}
};

// Minimizes the nlregr mean squared error with nlopt and writes the
// optimizer trace and the result to "nlregr.out".
//
// Returns 0 when minimize() reports success.  Returns 1 when the output
// file cannot be opened or when minimize() reports failure.
// The command-line and environment arguments are not used.
int main(int argc, char** argp, char** envp)
{
  #if defined USE_POW
    warn("MSE computed using pow function");
  #else
    warn("MSE computed using multiplication");
  #endif
  nlregr reg;
  nlopt minimizer(reg);
  ofstream os("nlregr.out");
  // Without this check a failed open would silently discard every line
  // written below while the program could still exit with status 0.
  if (!os) {
    warn("Error, nlregr, cannot open nlregr.out");
    return 1;
  }
  minimizer.set_output(true, &os);
  minimizer.set_check_derivatives(true);
  minimizer.set_warning_messages(true);
  realmat theta_start(3,1,0.5); realmat theta_stop;
  if (minimizer.minimize(theta_start, theta_stop)) {
    os << starbox("/The Answer!//") << theta_stop << '\n';
    os << starbox("/Status//") << '\n';
    os << "\t termination_code = " << minimizer.get_termination_code() << '\n';
    os << "\t iter_count = " << minimizer.get_iter_count() << '\n';
    os << "\t H = " << minimizer.get_H_matrix() << '\n';
    return 0;
  }
  else {
    os << starbox("/Failure/Check your derivatives!//");
    return 1;
  }
}
