## ----include = FALSE----------------------------------------------------------
# Probe for a usable torch installation: seeding and allocating a scalar
# tensor requires the native backend, so any error here means the chunks
# below should be displayed but not evaluated.
rkaf_torch_available <- tryCatch(
  expr = {
    torch::torch_manual_seed(123)
    torch::torch_tensor(0)
    TRUE
  },
  error = function(cond) FALSE
)

# Global chunk defaults: collapsed output prefixed with "#>", messages and
# warnings hidden, and evaluation switched on only when the probe succeeded.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  message = FALSE,
  warning = FALSE,
  eval = rkaf_torch_available
)

## ----torch-unavailable, echo = FALSE, eval = !rkaf_torch_available------------
# knitr::asis_output(
#   "> Torch/Lantern is not available in this build environment, so the code chunks are shown but not executed."
# )

## ----setup--------------------------------------------------------------------
# Attach the package demonstrated in this vignette.
library("rkaf")

# Seed both the base R and the torch random number generators so that
# the output below is reproducible.
set.seed(seed = 123)
torch::torch_manual_seed(seed = 123)

## ----regression-data----------------------------------------------------------
# 128 equally spaced inputs on [-1, 1], kept as a one-column matrix
# because that is the `x` handed to kaf_fit() below.
x <- matrix(seq(-1, 1, length.out = 128), ncol = 1)

# Target curve: a fast sine, a slower cosine and a small quadratic trend.
# Element-wise arithmetic on `x` keeps the 128 x 1 matrix shape.
y <- local({
  fast_wave <- sin(8 * pi * x)
  slow_wave <- 0.35 * cos(3 * pi * x)
  trend <- 0.15 * x^2
  fast_wave + slow_wave + trend
})

## ----regression-fit-----------------------------------------------------------
# Fit a KAF regressor to the 1-D curve. Inputs already lie in [-1, 1], so
# only the response is standardised. Argument order is unchanged; only the
# layout differs.
fit <- kaf_fit(
  x = x, y = y,
  hidden = c(256, 256), num_grids = 32, use_layernorm = FALSE,
  epochs = 1e3, lr = 0.001,
  standardize_x = FALSE, standardize_y = TRUE,
  fourier_init_scale = 0.05,
  restore_best = TRUE, verbose = FALSE, seed = 123
)

# Print the fitted model object.
fit

## ----regression-predict-------------------------------------------------------
# In-sample predictions on the training grid.
pred <- predict(fit, x)

# First rows of observed vs predicted; `y` is a one-column matrix, so it
# is flattened to a plain vector first.
head(
  data.frame(
    observed = round(as.numeric(y), digits = 3),
    predicted = round(pred, digits = 3)
  ),
  n = 6L
)

## ----regression-plot, fig.width=7, fig.height=4-------------------------------
# Observed curve as a solid line, fitted curve dashed on top of it.
plot(
  x, y,
  type = "l", lwd = 2,
  xlab = "x", ylab = "f(x)",
  main = "KAF regression fit"
)
lines(x, pred, lty = 2, lwd = 2)

# Legend styles mirror the two line calls above.
legend(
  "topright",
  legend = c("Observed", "Predicted"),
  lty = c(1, 2), lwd = 2, bty = "n"
)

## ----formula-regression-------------------------------------------------------
# Regression through the formula interface: mpg modelled from weight,
# horsepower and cylinder count in the built-in mtcars data.
fit_mtcars <- kaf_fit_formula(
  mpg ~ wt + hp + cyl,
  data = mtcars,
  hidden = c(32, 32), num_grids = 16,
  epochs = 200, verbose = FALSE, seed = 123
)

# Print the fitted model object.
fit_mtcars

## ----formula-predict----------------------------------------------------------
# In-sample predictions, shown next to the observed mpg for the first cars.
mtcars_pred <- predict(fit_mtcars, mtcars)

head(
  data.frame(
    observed = mtcars[["mpg"]],
    predicted = round(mtcars_pred, digits = 2)
  ),
  n = 6L
)

## ----binary-data--------------------------------------------------------------
# Binary target: TRUE/FALSE "mpg above the sample median", relabelled so
# that "no" is the first factor level and "yes" the second.
df <- mtcars
df$high_mpg <- factor(
  df$mpg > median(df$mpg),
  levels = c(FALSE, TRUE),
  labels = c("no", "yes")
)

## ----binary-fit---------------------------------------------------------------
# Binary classification through the formula interface, using the same
# predictors as the mtcars regression example.
fit_binary <- kaf_fit_formula(
  high_mpg ~ wt + hp + cyl,
  data = df,
  hidden = c(32, 32), num_grids = 16,
  epochs = 200, verbose = FALSE, seed = 123
)

# Print the fitted model object.
fit_binary

## ----binary-prob--------------------------------------------------------------
# Two prediction types from the same classifier: probabilities (shown in
# a column labelled prob_yes) and hard class labels.
prob_binary <- predict(fit_binary, df, type = "prob")
class_binary <- predict(fit_binary, df, type = "class")

head(
  data.frame(
    observed = df[["high_mpg"]],
    prob_yes = round(prob_binary, digits = 3),
    predicted = class_binary
  ),
  n = 6L
)

## ----binary-confusion-matrix--------------------------------------------------
# Confusion matrix: rows are the observed classes, columns the predictions.
table(df$high_mpg, class_binary, dnn = c("observed", "predicted"))

## ----binary-link--------------------------------------------------------------
head(predict(fit_binary, df, type = "link"), n = 6L)  # link-scale output, first rows

## ----multiclass-fit-----------------------------------------------------------
# Multiclass example: Species from every other column of iris
# (`.` in the formula stands for all remaining columns).
fit_iris <- kaf_fit_formula(
  Species ~ .,
  data = iris,
  hidden = c(32, 32), num_grids = 16,
  epochs = 300, verbose = FALSE, seed = 123
)

# Print the fitted model object.
fit_iris

## ----multiclass-confusion-matrix----------------------------------------------
# Hard class labels for every flower in the training data.
class_iris <- predict(fit_iris, iris, type = "class")

# Confusion matrix: rows are the observed species, columns the predictions.
table(iris$Species, class_iris, dnn = c("observed", "predicted"))

## ----multiclass-prob----------------------------------------------------------
round(head(predict(fit_iris, iris, type = "prob")), digits = 3)  # first rows, 3 d.p.

## ----validation-fit, eval=FALSE-----------------------------------------------
# fit_val <- kaf_fit(
#   x = x,
#   y = y,
#   hidden = c(64, 64),
#   num_grids = 16,
#   use_layernorm = FALSE,
#   epochs = 300,
#   lr = 5e-4,
#   batch_size = 64,
#   validation_split = 0.2,
#   patience = 100,
#   restore_best = TRUE,
#   verbose = FALSE,
#   seed = 123
# )
# 
# plot(fit_val)

## ----diagnostics-scales-------------------------------------------------------
# Scale parameters pulled from the regression model fitted above.
scales <- extract_kaf_scales(fit)

head(scales, n = 6L)

## ----diagnostics-fourier------------------------------------------------------
# Fourier parameters of the first layer of the same regression model.
fourier_params <- extract_fourier_params(fit, layer = 1)

head(fourier_params, n = 6L)

## ----low-level----------------------------------------------------------------
# Low-level use: build the torch module directly. `layers` lists the layer
# sizes, presumably input width 4 through output width 1.
model <- nn_kaf(layers = c(4, 16, 16, 1), num_grids = 8)

# Forward a random 10 x 4 batch and report the shape of the result.
x_tensor <- torch::torch_randn(10, 4)
y_tensor <- model(x_tensor)

y_tensor$shape

## ----workflow, eval=FALSE-----------------------------------------------------
# fit <- kaf_fit_formula(
#   y ~ .,
#   data = df,
#   hidden = c(64, 64),
#   num_grids = 16,
#   validation_split = 0.2,
#   patience = 30
# )
# 
# predict(fit, newdata)
# plot(fit)
# extract_kaf_scales(fit)

## ----classification-workflow, eval=FALSE--------------------------------------
# predict(fit, newdata, type = "prob")
# predict(fit, newdata, type = "class")

