#include "TmW.h"
#include <Rcpp.h>
//#include <RcppArmadillo.h>

using namespace Rcpp;
using namespace arma;
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::export]]

SEXP TmW_Gauss_cpp(SEXP x) {
  mat Z = as<arma::mat>(x);
  int p = Z.n_cols;
  int n = Z.n_rows;
  
  double p1 = 0;
  arma::vec p2_partial = arma::zeros<arma::vec>(p);
  double temp = 0;
  
  for (int i = 0; i < n; ++i) {
    arma::mat Cik = arma::ones<arma::mat>(n, p);
    for (int j = 0; j < n; ++j) {
      if (i == j) continue;
      
      arma::rowvec Zij = Z.row(i) - Z.row(j);
      arma::rowvec Cij = exp(-arma::square(Zij));
      
      if (j < i) {
        p1 += arma::prod(Cij);
        p2_partial += Cij.t();
      }
      
      Cik.row(j) = Cij;
    }
    temp += arma::prod(arma::sum(Cik, 0));
  }
  
  p1 = (2 * p1 + n) / n;
  double p2 = arma::prod(2 * p2_partial + n) / std::pow(n, 2 * p - 1);
  double p3 = 2 * temp / std::pow(n, p);
  double W = p1 + p2 - p3;
  
  return Rcpp::List::create(Rcpp::Named("W") = W);
  /*return W;*/
}


SEXP TmW_Lap_cpp(SEXP x) {
  mat Z = as<arma::mat>(x);
  int p = Z.n_cols;
  int n = Z.n_rows;
  
  double p1 = 0;
  arma::vec p2_partial = arma::zeros<arma::vec>(p);
  double temp = 0;
  
  for (int i = 0; i < n; ++i) {
    arma::mat Cik = arma::ones<arma::mat>(n, p);
    for (int j = 0; j < n; ++j) {
      if (i == j) continue;
      
      arma::rowvec Zij = Z.row(i) - Z.row(j);
      arma::rowvec Cij = 1/(1+arma::square(Zij));
      
      if (j < i) {
        p1 += arma::prod(Cij);
        p2_partial += Cij.t();
      }
      
      Cik.row(j) = Cij;
    }
    temp += arma::prod(arma::sum(Cik, 0));
  }
  
  p1 = (2 * p1 + n) / n;
  double p2 = arma::prod(2 * p2_partial + n) / std::pow(n, 2 * p - 1);
  double p3 = 2 * temp / std::pow(n, p);
  double W = p1 + p2 - p3;
  
  return Rcpp::List::create(Rcpp::Named("W") = W);
}
