#' An Automatic Method for the Analysis of Experiments using Hierarchical Garrote
#' 
#' `HiGarrote()` provides an automatic method for analyzing experimental data. 
#' This function applies the nonnegative garrote method to select important effects while preserving their hierarchical structures.
#' It first estimates regression parameters using generalized ridge regression, where the ridge parameters are derived from a Gaussian process prior placed on the input-output relationship. 
#' The initial estimates are then used in the nonnegative garrote for effect selection.
#' 
#' @param D An \eqn{n \times p} data frame for the unreplicated design matrix, where \eqn{n} is the run size and \eqn{p} is the number of factors.
#' @param y A vector for the responses corresponding to \code{D}. For replicated experiments, \code{y} should be an \eqn{n \times r} matrix, where \eqn{r} is the number of replicates.
#' @param heredity Specifies the heredity principles to be used. Supported options are \code{"weak"} and \code{"strong"}. The default is \code{"weak"}.
#' @param quali_id A vector indexing qualitative factors. 
#' Qualitative factors are coded using Helmert coding. 
#' Different coding systems are allowed by specifying \code{quali_sum_idx}, \code{user_def_coding}, \code{user_def_coding_idx}.
#' @param quanti_id A vector indexing quantitative factors. Quantitative factors are coded using orthogonal polynomial coding.
#' @param quali_sum_idx Optional. A vector indicating which qualitative factors should use sum coding (\code{contr.sum()}).
#' @param user_def_coding Optional. A list of user-defined orthogonal coding systems.
#' Each element must be an orthogonal contrast matrix.
#' @param user_def_coding_idx Optional.
#' A list of indices specifying which qualitative factors should use the corresponding coding systems provided in \code{user_def_coding}.
#' @param model_type Integer indicating the type of model to construct.
#' \describe{
#' \item{model_type = 1}{The model matrix includes all the main effects of qualitative factors, 
#' the first two main effects (linear and quadratic) of all the quantitative factors, 
#' and all the two-factor interactions generated by those main effects.
#' }
#' \item{model_type = 2}{The model matrix includes all the main effects of qualitative factors,
#' the linear effects of all the quantitative factors, 
#' all the two-factor interactions generated by those main effects, 
#' and the quadratic effects of all the quantitative factors.}
#' \item{model_type = 3}{The model matrix includes all the main effects of qualitative factors 
#' and the linear effects of all the quantitative factors.}
#' }
#' The default is \code{model_type = 1}.
#' 
#' @returns The function returns a list with:
#' \describe{
#' \item{\code{nng_estimate}}{A vector for the nonnegative garrote estimates of the identified effects.}
#' \item{\code{U}}{A model matrix of \code{D}.}
#' \item{\code{pred_info}}{A list containing information needed for future predictions.}
#' }
#' 
#' @export
#' @examples
#' # Cast fatigue experiment
#' data(cast_fatigue)
#' X <- cast_fatigue[,1:7]
#' y <- cast_fatigue[,8]
#' fit_Hi <- HiGarrote::HiGarrote(X, y)
#' fit_Hi$nng_estimate
#' 
#' # Blood glucose experiment
#' data(blood_glucose)
#' X <- blood_glucose[,1:8]
#' y <- blood_glucose[,9]
#' fit_Hi <- HiGarrote::HiGarrote(X, y, quanti_id = 2:8) 
#' fit_Hi$nng_estimate
#' 
#' \donttest{
#' # Router bit experiment --- Use default Helmert coding
#' data(router_bit)
#' X <- router_bit[, 1:9]
#' y <- router_bit[,10]
#' fit_Hi <- HiGarrote::HiGarrote(X, y, quali_id = c(4,5))
#' fit_Hi$nng_estimate
#' 
#' # Router bit experiment --- Use sum coding
#' fit_Hi <- HiGarrote::HiGarrote(X, y, quali_id = c(4,5), quali_sum_idx = c(4,5))
#' fit_Hi$nng_estimate
#' 
#' # Router bit experiment --- Use user-defined coding system for qualitative factors
#' fit_Hi <- HiGarrote::HiGarrote(X, y, quali_id = c(4,5),
#'  user_def_coding = list(matrix(c(-1,-1,1,1,1,-1,-1,1,-1,1,-1,1), ncol = 3)),
#'  user_def_coding_idx = list(c(4,5)))
#' fit_Hi$nng_estimate
#' 
#' # Resin experiment --- Use model_type = 2
#' data(resin)
#' X <- resin[,1:9]
#' y <- log(resin$Impurity)
#' fit_Hi <- HiGarrote::HiGarrote(X, y, quanti_id = c(1:9), model_type = 2)
#' fit_Hi$nng_estimate
#' 
#' # Epoxy experiment --- Use model_type = 3
#' data(epoxy)
#' X <- epoxy[,1:23]
#' y <- epoxy[,24]
#' fit_Hi <- HiGarrote::HiGarrote(X, y, model_type = 3)
#' fit_Hi$nng_estimate
#' 
#' # Experiments with replicates
#' # Generate simulated data
#' data(cast_fatigue)
#' X <- cast_fatigue[,1:7]
#' U <- data.frame(model.matrix(~.^2, X)[,-1])
#' error <- matrix(rnorm(24), ncol = 2) # two replicates for each run
#' y <- 20*U$A + 10*U$A.B + 5*U$A.C + error
#' fit_Hi <- HiGarrote::HiGarrote(X, y)
#' fit_Hi$nng_estimate
#' }
#' 
#' @references
#' Yu, W. Y. and Joseph, V. R. (2025). Automated Analysis of Experiments using Hierarchical Garrote. 
#' \emph{Journal of Quality Technology}, 1-15. \doi{10.1080/00224065.2025.2513508}.

HiGarrote <- function(D, y, heredity = "weak",
                      quali_id = NULL, quanti_id = NULL,
                      quali_sum_idx = NULL,
                      user_def_coding = NULL, user_def_coding_idx = NULL,
                      model_type = 1) {
  # Obtain model matrix and corresponding info
  info <- model_matrix(D, quali_id, quanti_id,
                       quali_sum_idx,
                       user_def_coding, user_def_coding_idx,
                       model_type)
  D <- info$D
  n <- info$n
  p <- info$p
  uni_level <- info$uni_level
  mi <- info$mi
  two_level_id <- info$two_level_id
  quali_id <- info$quali_id
  me_num <- info$me_num
  U_j_list <- info$U_j_list
  U <- info$U
  effects_name <- colnames(U)

  # Retain model matrix information for prediction
  pred_model_matrix_info <- list(uni_level = uni_level, two_level_id = two_level_id,
                                 quali_id = quali_id, quanti_id = quanti_id,
                                 me_num = me_num,
                                 contrast = info$contrast,
                                 model_type = model_type,
                                 initial_colnames = info$initial_colnames,
                                 n = n, quanti_D_norm = info$quanti_D_norm)

  # Preprocess response: average replicates within each run, then standardise.
  y <- as.matrix(y)
  y_mean <- mean(y)
  replicate <- ncol(y)
  # split() flattens the n x r matrix column-major, so this labels each
  # element with its run (row) index.
  run <- rep(seq_len(n), replicate)
  y_s2 <- y
  y <- sapply(split(y, run), mean)
  ori_y <- y
  y_sd <- sd(y)
  y <- scale(y)
  s2 <- 0.0
  if (replicate != 1) {
    # Mean within-run variance, on the scale of the standardised response.
    y_s2 <- y_s2 / y_sd
    s2 <- mean(sapply(split(y_s2, run), var))
  }

  # Generate h_list
  h_list <- lapply(1:p, function(i) {
    h_dist(D[, i], (i %in% two_level_id), (i %in% quali_id))
  })
  h_list_mat <- rlist::list.flatten(h_list)

  # Generate h_j_list
  h_j_list <- h_j_cpp(p, uni_level, U_j_list, two_level_id, quali_id)
  h_j_list <- unlist(h_j_list, recursive = FALSE)
  # Number of rho parameters contributed by each element of h_j_list.
  rho_len <- ifelse(sapply(h_j_list, is.list) == FALSE, 1, lengths(h_j_list))

  # Optimize for rho_lambda (P rho values plus one lambda), multi-start from a
  # MaxPro Latin hypercube rescaled into [0.01, 0.99].
  P <- sum(rho_len)
  ini_point <- MaxPro::MaxProLHD(P + 1, P + 1)
  ini_point0 <- ini_point$Design
  ini_point0 <- scales::rescale(ini_point0, to = c(0.01, 0.99))
  rho_lambda_list <- rho_lambda_optim(ini_point0, h_list_mat, n, replicate, y, 0.01, 0.99)
  rho_lambda_list <- unlist(rho_lambda_list, recursive = FALSE)
  rho_lambda_obj_value <- -unlist(purrr::map(rho_lambda_list, "objective"))
  rho_lambda <- rho_lambda_list[[which.max(rho_lambda_obj_value)]]$solution
  rho <- rho_lambda[1:P]
  lambda <- rho_lambda[P + 1]
  # Split the flat rho vector back into per-element chunks of length rho_len.
  rho_list <- purrr::map2(cumsum(c(0, rho_len[-length(rho_len)])) + 1, cumsum(rho_len), ~ rho[.x:.y])

  # Generate R
  initialize_BETA_instance(h_j_list, p, rho_list, mi)
  r_j <- r_j_cpp_R(U_j_list, me_num)
  r_j <- unlist(r_j)
  me_idx <- which(!stringr::str_detect(effects_name, ":")) # main effects idx
  hoe_idx <- which(stringr::str_detect(effects_name, ":")) # higher-order effects idx
  names(r_j) <- effects_name[me_idx]
  R <- c(rep(1, length(effects_name)))
  names(R) <- effects_name
  # main effects
  R[intersect(names(R), names(r_j))] <- r_j[intersect(names(R), names(r_j))]
  # higher order effects: product of the r_j of their component main effects
  if (length(hoe_idx) != 0) {
    hoe_names <- stringr::str_split(effects_name[hoe_idx], ":")
    R_hoe <- lapply(hoe_names, function(i) {
      prod(r_j[i])
    })
    R_hoe <- unlist(R_hoe)
    R[hoe_idx] <- R[hoe_idx] * R_hoe
  }

  # Heredity
  if (heredity == "strong") {
    A1 <- gstrong(U)
  } else {
    A1 <- gweak(U)
  }

  beta_ele <- beta_ele_cpp_R(U, R, lambda, replicate, n, y)
  # Symmetrise, then repair to the nearest positive-definite matrix if needed.
  beta_ele$Dmat <- (beta_ele$Dmat + t(beta_ele$Dmat)) / 2
  if (!matrixcalc::is.positive.definite(beta_ele$Dmat)) {
    D_nng <- Matrix::nearPD(beta_ele$Dmat)
    beta_ele$Dmat <- as.matrix(D_nng$mat)
  }
  beta_nng <- beta_nng_cpp_R(beta_ele, replicate, n, y, A1, s2)
  beta_shrink <- round(beta_nng, 6)
  names(beta_shrink) <- effects_name
  beta_shrink <- beta_shrink[which(beta_shrink != 0)]
  beta_shrink <- beta_shrink[order(abs(beta_shrink), decreasing = TRUE)]
  beta_shrink <- beta_shrink * y_sd

  # Refit a linear regression model for prediction.
  # data.frame() sanitises column names with make.names() (e.g. "A:B" -> "A.B",
  # "A:A" -> "A.A"). Refer to those sanitised columns directly rather than
  # writing the raw effect names into the formula: there "A:A" would collapse
  # to the linear term "A" instead of the selected quadratic column.
  # predict.HiGarrote() passes data.frame(U) as newdata, which produces the
  # same sanitised names.
  terms <- names(beta_shrink)
  fit_data <- data.frame(U, ori_y)
  fit_terms <- colnames(fit_data)[match(terms, effects_name)]
  # With no selected effects, fall back to an intercept-only model.
  rhs <- if (length(fit_terms) == 0) "1" else paste(fit_terms, collapse = " + ")
  form <- as.formula(paste("ori_y ~", rhs))
  lmod <- lm(form, fit_data)

  # Return
  pred_info <- list(pred_model_matrix_info = pred_model_matrix_info,
                    beta_nng = beta_nng,
                    y_mean = y_mean, y_sd = y_sd, lmod = lmod)
  return_list <- list(nng_estimate = beta_shrink, U = U, pred_info = pred_info)
  class(return_list) <- "HiGarrote"
  return_list
}


#' Make Predictions from a "HiGarrote" Object
#' 
#' This function makes predictions from a linear model constructed using the important effects selected by \code{HiGarrote}.
#' 
#' @param object A \code{HiGarrote} object.
#' @param new_D A new design matrix where predictions are to be made.
#' @param ... Additional arguments, accepted for compatibility with the \code{predict} generic; currently ignored.
#' 
#' @returns The function returns a list with:
#' \describe{
#' \item{\code{new_U}}{A model matrix of \code{new_D}.}
#' \item{\code{prediction_nng}}{Predictions for \code{new_D}. The coefficients of the predictive equation are based on nonnegative garrote estimates.}
#' \item{\code{prediction_lm}}{Predictions for \code{new_D}. The coefficients of the predictive equation are estimated via ordinary least squares.}
#' }
#' 
#' @export
#' @method predict HiGarrote
#' @examples
#' # Cast fatigue experiment
#' data(cast_fatigue)
#' X <- cast_fatigue[1:10,1:7]
#' y <- cast_fatigue[1:10,8]
#' fit_Hi <- HiGarrote::HiGarrote(X, y)
#' 
#' # make predictions
#' new_D <- cast_fatigue[11:12,1:7]
#' pred_Hi <- predict(fit_Hi, new_D)

predict.HiGarrote <- function(object, new_D, ...) {
  if (is.null(new_D)) stop("Must provide 'new_D' for prediction.")
  new_D <- data.frame(new_D)

  # Extract information for creating model matrix
  two_level_id <- object$pred_info$pred_model_matrix_info$two_level_id
  quali_id <- object$pred_info$pred_model_matrix_info$quali_id
  quanti_id <- object$pred_info$pred_model_matrix_info$quanti_id
  me_num <- object$pred_info$pred_model_matrix_info$me_num
  uni_level <- object$pred_info$pred_model_matrix_info$uni_level
  contrast <- object$pred_info$pred_model_matrix_info$contrast
  model_type <- object$pred_info$pred_model_matrix_info$model_type
  initial_colnames <- object$pred_info$pred_model_matrix_info$initial_colnames
  lmod <- object$pred_info$lmod
  n <- object$pred_info$pred_model_matrix_info$n
  quanti_D_norm <- object$pred_info$pred_model_matrix_info$quanti_D_norm

  # Error handling.
  # Validate the column layout first: the level check below indexes new_D by
  # column position. identical() also rejects a different number of columns,
  # which `==` would silently recycle.
  if (!identical(as.character(initial_colnames), colnames(new_D))) {
    stop("Column names of new_D must match those of D in object exactly.")
  }
  a <- c(two_level_id, quali_id)
  if (!is.null(a)) {
    for (i in a) {
      if (!all(unique(new_D[, i]) %in% uni_level[[i]])) {
        stop("Levels of qualitative factors in new_D must be included in D.")
      }
    }
  }

  # Get contrast for qualitative and quantitative factors
  for (i in quali_id) {
    new_D[, i] <- factor(new_D[, i], levels = uni_level[[i]])
    contrasts(new_D[, i]) <- contrast[[i]]
  }
  for (i in quanti_id) {
    new_D[, i] <- predict(contrast[[i]], newdata = new_D[, i])
    # Rescale the polynomial contrasts with the training normalisation.
    new_D[, i] <- sweep(new_D[, i][, (1:me_num[i])], 2, sqrt(n / quanti_D_norm[[i]]), "*")
  }

  # Get model matrix for new_D
  if (model_type == 2) { # full quadratic model
    D1 <- new_D
    if (!is.null(quanti_id)) {
      for (i in quanti_id) {
        D1[, i] <- new_D[, i][, 1] # linear effect
      }
    }
    U1 <- model.matrix(~.^2, D1)
    D2 <- data.frame(matrix(nrow = nrow(new_D), ncol = 0))
    if (!is.null(quanti_id)) {
      for (i in quanti_id) {
        D2 <- cbind(D2, new_D[, i][, 2])
        colnames(D2)[ncol(D2)] <- paste0(initial_colnames[i], ":", initial_colnames[i])
      }
    }
    U2 <- model.matrix(~., D2)
    # drop = FALSE keeps U2 a matrix when there is a single quadratic column,
    # so that colnames<- below does not fail on a plain vector.
    U2 <- U2[, -1, drop = FALSE]
    colnames(U2) <- gsub("`", "", colnames(U2))
    U <- cbind(U1, U2)

  } else if (model_type == 3) { # main effects model
    if (!is.null(quanti_id)) {
      for (i in quanti_id) {
        colnames(new_D[, i]) <- c(".1", paste0(".1", ":", paste0(initial_colnames[i], ".1")))
      }
    }
    U <- model.matrix(~., new_D)
  } else {
    if (!is.null(quanti_id)) {
      for (i in quanti_id) {
        colnames(new_D[, i]) <- c(".1", ".2")
      }
    }
    U <- model.matrix(~.^2, new_D)
  }
  # Drop the intercept. drop = FALSE keeps a one-row new_D as a 1 x P matrix
  # so data.frame(U) below has one row per prediction point.
  U <- U[, -1, drop = FALSE]

  # Garrote prediction on the original response scale
  y_mean <- object$pred_info$y_mean
  y_sd <- object$pred_info$y_sd
  beta_nng <- object$pred_info$beta_nng
  beta_nng <- beta_nng * y_sd
  pred_nng <- y_mean + as.numeric(U %*% beta_nng)

  # refit linear model prediction
  pred_lm <- as.numeric(predict(lmod, data.frame(U)))

  result <- list(new_U = U, prediction_nng = pred_nng, prediction_lm = pred_lm)
  result
}



#' Nonnegative Garrote Method with Hierarchical Structures
#'
#' `nnGarrote()` implements the nonnegative garrote method, as described in Yuan et al. (2009), for selecting important variables while preserving hierarchical structures.
#' The method begins by obtaining the least squares estimates of the regression parameters under a linear model.
#' These initial estimates are then used in the nonnegative garrote to perform variable selection.
#' Note that this method is suitable only when the number of observations is much larger than the number of variables, ensuring that the least squares estimation remains reliable.
#'
#' @param X An \eqn{n \times p} input matrix, where \eqn{n} is the number of observations and \eqn{p} is the number of variables.
#' @param y A vector for the responses.
#' @param heredity Specifies the heredity principles to be used. Supported options are \code{"weak"} and \code{"strong"}. The default is \code{"weak"}.
#' @param model_type Integer indicating the type of model to construct.
#' \describe{
#' \item{model_type = 1}{The model matrix includes linear effects, two-factor interactions, and quadratic effects.}
#' \item{model_type = 2}{The model matrix includes linear effects and two-factor interactions.}
#' }
#' The default is \code{model_type = 1}.
#' 
#' @returns The function returns a list with:
#' \describe{
#' \item{\code{nng_estimate}}{A vector for the nonnegative garrote estimates of the identified variables.}
#' \item{\code{U}}{A scaled model matrix.}
#' \item{\code{pred_info}}{A list of information required for further prediction.}
#' }
#'
#' @export
#' @examples
#' # Generate data
#' x1 <- runif(100)
#' x2 <- runif(100)
#' x3 <- runif(100)
#' error <- rnorm(100)
#' X <- data.frame(x1, x2, x3)
#' U <- model.matrix(~. + x1:x2 + x1:x3 + x2:x3 + I(x1^2) + I(x2^2) + I(x3^2) - 1, X)
#' U <- data.frame(scale(U))
#' colnames(U) <- c("x1", "x2", "x3", "x1:x1", "x2:x2", "x3:x3", "x1:x2", "x1:x3", "x2:x3")
#' y <- 3 + 3*U$x1 + 3*U$`x1:x1` + 3*U$`x1:x2`+ 3*U$`x1:x3` + error
#' 
#' # Fit nnGarrote
#' fit_nng <- HiGarrote::nnGarrote(X, y)
#' fit_nng$nng_estimate
#'
#' @references
#' Yuan, M., Joseph, V. R., and Zou H. (2009). Structured Variable Selection and Estimation. 
#' \emph{The Annals of Applied Statistics}, 3(4), 1738–1757. \doi{10.1214/09-AOAS254}.

nnGarrote <- function(X, y, heredity = "weak", model_type = 1) {
  # Preprocess: keep the raw response for the OLS refit, standardise the rest.
  X <- data.frame(X)
  ori_y <- y
  y_mean <- mean(y)
  y_sd <- sd(y)
  y <- scale(y)

  # Get model matrix
  if (model_type == 2) { # main effects + two-factor interactions model
    U <- model.matrix(~.^2, X)
  } else { # full quadratic model
    vars <- colnames(X)
    form <- as.formula(
      paste("~ (", paste(vars, collapse = " + "), ")^2 + ",
            paste0("I(", vars, "^2)", collapse = " + "))
    )
    U <- model.matrix(form, X)
    # Rename quadratic terms "I(x^2)" to "x:x" so heredity helpers see them
    # as effects built from x.
    colnames(U) <- gsub("^I\\(([^)]+)\\^2\\)$", "\\1:\\1", colnames(U))
  }
  U <- U[, -1] # drop the intercept column

  # The dimensions must come from the expanded model matrix (U did not exist
  # before this point). After centring by scale(), U has rank at most n - 1
  # and the least squares fit has no intercept, so n must strictly exceed P.
  n <- nrow(U)
  P <- ncol(U)
  if (n <= P) {
    stop("n must be larger than P.")
  }

  U_mean <- apply(U, 2, mean)
  U_sd <- apply(U, 2, sd)
  effects_name <- colnames(U)
  U <- scale(U)

  # Least squares estimates (initial estimates for the garrote)
  dat <- data.frame(U, y)
  lmod <- lm(y ~ . - 1, data = dat)
  beta <- lmod$coefficients
  names(beta) <- effects_name

  # Heredity
  if (heredity == "strong") {
    A1 <- gstrong(U)
  } else {
    A1 <- gweak(U)
  }

  # Garrote QP: minimise ||y - Z c||^2 with Z = U diag(beta)
  B <- diag(beta)
  Z <- U %*% B
  D.mat <- t(Z) %*% Z
  d <- t(Z) %*% y
  D.mat <- (D.mat + t(D.mat)) / 2

  if (!matrixcalc::is.positive.definite(D.mat)) {
    D.mat <- Matrix::nearPD(D.mat)
    D.mat <- as.matrix(D.mat$mat)
  }

  # Choose the budget M by GCV over a grid; M enters b0 as the bound of the
  # first constraint (presumably the sum-of-shrinkage-factors constraint
  # encoded in the first column of A1 -- see gweak()/gstrong()).
  n_grid <- 100
  M <- seq(0.1, length(beta), length.out = n_grid)
  gcv <- numeric(n_grid)
  for (i in seq_len(n_grid)) {
    b0 <- c(-M[i], rep(0, dim(A1)[2] - 1))
    coef_nng <- quadprog::solve.QP(D.mat, d, A1, b0)$sol
    e <- y - Z %*% coef_nng
    gcv[i] <- sum(e^2) / (n * (1 - M[i] / n)^2)
  }
  M <- M[which.min(gcv)]
  b0 <- c(-M, rep(0, dim(A1)[2] - 1))
  coef_nng <- round(quadprog::solve.QP(D.mat, d, A1, b0)$sol, 10)
  beta_nng <- B %*% coef_nng

  # Report non-zero estimates on the original response scale, largest first.
  beta_shrink <- round(beta_nng, 6)
  names(beta_shrink) <- effects_name
  beta_shrink <- beta_shrink[which(beta_shrink != 0)]
  beta_shrink <- beta_shrink[order(abs(beta_shrink), decreasing = TRUE)]
  beta_shrink <- beta_shrink * y_sd
  # Coefficients for the unscaled model matrix, plus the centring offset.
  beta_scale <- beta_nng * y_sd
  beta_adj <- sum((U_mean / U_sd) * beta_scale)
  beta_scale <- beta_scale / U_sd

  # Refit a linear regression model for prediction
  terms <- names(beta_shrink)
  terms <- vapply(terms, function(term) {
    vars <- strsplit(term, ":")[[1]]
    if (length(vars) == 2 && vars[1] == vars[2]) {
      paste0("I(", vars[1], "^2)")
    } else {
      gsub(":", "*", term) # convert regular interactions
    }
  }, character(1))
  # With no selected effects, fall back to an intercept-only model instead of
  # building the unparsable formula "ori_y ~ ".
  rhs <- if (length(terms) == 0) "1" else paste(terms, collapse = " + ")
  form <- as.formula(paste("ori_y ~", rhs))
  lmod <- lm(form, data.frame(X, ori_y))

  # return
  pred_info <- list(model_type = model_type,
                    beta_scale = beta_scale, beta_adj = beta_adj,
                    y_sd = y_sd, y_mean = y_mean, lmod = lmod)

  result <- list(nng_estimate = beta_shrink, U = U,
                 pred_info = pred_info)
  class(result) <- "nnGarrote"
  result
}

#' Make Predictions from a "nnGarrote" Object
#'
#' This function makes predictions from a linear model constructed using the important effects selected by \code{nnGarrote}.
#'
#' @param object An \code{nnGarrote} object.
#' @param new_X A new input matrix where predictions are to be made.
#' @param ... Additional arguments, accepted for compatibility with the \code{predict} generic; currently ignored.
#' 
#' @returns
#' The function returns a list with:
#' \describe{
#' \item{\code{new_U}}{A model matrix of \code{new_X}.}
#' \item{\code{prediction_nng}}{Predictions for \code{new_X}. The coefficients of the predictive equation are based on nonnegative garrote estimates.}
#' \item{\code{prediction_lm}}{Predictions for \code{new_X}. The coefficients of the predictive equation are estimated via ordinary least squares.}
#' }
#' 
#' @export
#' @method predict nnGarrote
#' @examples
#' # Generate data
#' x1 <- runif(100)
#' x2 <- runif(100)
#' x3 <- runif(100)
#' error <- rnorm(100)
#' X <- data.frame(x1, x2, x3)
#' U <- model.matrix(~. + x1:x2 + x1:x3 + x2:x3 + I(x1^2) + I(x2^2) + I(x3^2) - 1, X)
#' U <- data.frame(scale(U))
#' colnames(U) <- c("x1", "x2", "x3", "x1:x1", "x2:x2", "x3:x3", "x1:x2", "x1:x3", "x2:x3")
#' y <- 3 + 3*U$x1 + 3*U$`x1:x1` + 3*U$`x1:x2`+ 3*U$`x1:x3` + error
#' 
#' # training and testing set
#' train_idx <- sample(1:100, 80)
#' X_train <- X[train_idx,]
#' y_train <- y[train_idx]
#' X_test <- X[-train_idx,]
#' y_test <- y[-train_idx]
#' 
#' # fit nnGarrote
#' fit_nng <- HiGarrote::nnGarrote(X_train, y_train)
#' 
#' # predict
#' pred_nng <- predict(fit_nng, X_test)

predict.nnGarrote <- function(object, new_X, ...) {
  if (is.null(new_X)) stop("Must provide 'new_X' for prediction.")
  new_X <- data.frame(new_X)

  # Extract information for prediction
  model_type <- object$pred_info$model_type
  beta_scale <- object$pred_info$beta_scale
  beta_adj <- object$pred_info$beta_adj
  y_mean <- object$pred_info$y_mean
  y_sd <- object$pred_info$y_sd
  U <- object$U
  lmod <- object$pred_info$lmod

  # Get model matrix (same construction as in nnGarrote())
  if (model_type == 2) { # main effects + two-factor interactions model
    new_U <- model.matrix(~.^2, new_X)
  } else { # full quadratic model
    vars <- colnames(new_X)
    form <- as.formula(
      paste("~ (", paste(vars, collapse = " + "), ")^2 + ",
            paste0("I(", vars, "^2)", collapse = " + "))
    )
    new_U <- model.matrix(form, new_X)
  }
  # Drop the intercept. drop = FALSE keeps a one-row new_X as a matrix so
  # that ncol() below is defined.
  new_U <- new_U[, -1, drop = FALSE]
  if (ncol(new_U) != ncol(U)) {
    stop("Number of columns in new_X must be the same as that in X.")
  }
  colnames(new_U) <- colnames(U)

  # Prediction: beta_scale acts on the unscaled model matrix; beta_adj
  # removes the training-set centring.
  new_U <- as.matrix(new_U)
  pred_nng <- y_mean + as.numeric(new_U %*% beta_scale) - beta_adj
  pred_lm <- as.numeric(predict(lmod, new_X))

  result <- list(new_U = new_U, prediction_nng = pred_nng, prediction_lm = pred_lm)
  result
}

