---
title: "variable selection: non-linear data"
output: rmarkdown::html_vignette
params:
  eval: true
vignette: >
  %\VignetteIndexEntry{variable selection: non-linear data}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
library(LBBNN)

# Every chunk below is evaluated only when torch and its backend are
# installed, so the vignette still builds on machines without torch.
has_torch <- requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()
```

## Generate data

We again generate 1000 samples with 15 features, and make 6 of them relevant
for the outcome. This is a more complicated problem than the previous one, as
we include some non-linear effects. This time we transform the outcome into a
binary variable.

```{r, eval = has_torch}
i <- 1000 # number of samples
j <- 15   # number of features

# Seed both the R and the torch random number generators.
set.seed(42)
torch::torch_manual_seed(42)

# Features are drawn uniformly from (0, 0.5).
X_nl <- matrix(runif(i * j, 0, 0.5), ncol = j)

# Only the first six columns enter the latent outcome: log, cosine,
# interaction, linear and quadratic terms, plus Gaussian noise.
y_nl <- -3 +
  0.1 * log(abs(X_nl[, 1])) +
  3 * cos(X_nl[, 2]) +
  2 * X_nl[, 3] * X_nl[, 4] +
  X_nl[, 5] -
  X_nl[, 6]^2 +
  rnorm(i, sd = 0.1)

# Binarize the outcome: 1 above the median, 0 otherwise.
y <- as.numeric(y_nl > median(y_nl))

sim_data_nl <- cbind(as.data.frame(X_nl), y)

loaders_nl <- get_dataloaders(
  sim_data_nl,
  train_proportion = 0.9,
  train_batch_size = 450,
  test_batch_size = 100,
  standardize = FALSE
)
train_loader_nl <- loaders_nl$train_loader
test_loader_nl <- loaders_nl$test_loader
```

## Define hyperparameters and the model object

We use the same architecture as in the example with linear data. For this
example, we use normalizing flows in the variational distribution.

```{r, eval = has_torch}
problem <- "binary classification"
# Layer sizes, starting from the j input features.
# NOTE(review): presumably two hidden layers of 5 units and one output unit;
# confirm against the lbbnn_net documentation.
sizes <- c(j, 5, 5, 1)
incl_priors <- c(0.5, 0.5, 0.5)
stds <- c(1, 1, 1)
incl_inits <- "polarized"
device <- "cpu"

model_nl <- lbbnn_net(
  problem_type = problem,
  sizes = sizes,
  prior = incl_priors,
  inclusion_inits = incl_inits,
  input_skip = TRUE,
  std = stds,
  flow = TRUE, # normalizing flows in the variational distribution
  dims = c(10, 10, 10),
  device = device,
  bias_inclusion_prob = FALSE
)
```

## Train and validate the model

```{r, eval = has_torch}
train_lbbnn(
  epochs = 20,
  LBBNN = model_nl,
  lr = 0.2,
  train_dl = train_loader_nl,
  device = device,
  verbose = FALSE
)

validate_lbbnn(
  LBBNN = model_nl,
  num_samples = 2,
  test_dl = test_loader_nl,
  device = device
)
```

## Check the global explanations

```{r, fig.width = 6, fig.height = 6, eval = has_torch}
plot(
  model_nl,
  type = "global",
  vertex_size = 7,
  edge_width = 0.4,
  label_size = 0.4
)
```

All the relevant features are included.