---
title: "Getting started with rkaf"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Getting started with rkaf}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
# Vignette-wide chunk defaults: compact "#>"-prefixed output, and no
# messages or warnings in the rendered HTML.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  message = FALSE,
  warning = FALSE
)

# Probe for a usable torch backend. Every code chunk below is evaluated only
# when this is TRUE, so the vignette still builds (showing code only) on
# machines without torch or without its backend binaries.
#  1. requireNamespace(): is the torch R package installed at all?
#  2. torch_is_installed(): are the Lantern/LibTorch binaries present?
#  3. torch_tensor(0): can a tensor actually be created? This catches
#     installations that are present but broken.
# Any error along the way counts as "not available".
rkaf_torch_available <- isTRUE(tryCatch(
  requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed() &&
    !is.null(torch::torch_tensor(0)),
  error = function(e) FALSE
))

knitr::opts_chunk$set(eval = rkaf_torch_available)
```

```{r torch-unavailable, echo = FALSE, eval = !rkaf_torch_available}
# Rendered only when the torch probe in the hidden setup chunk failed; the
# leading ">" turns the note into a Markdown blockquote.
unavailable_note <- "> Torch/Lantern is not available in this build environment, so the code chunks are shown but not executed."
knitr::asis_output(unavailable_note)
```

# Overview

`rkaf` provides Kolmogorov-Arnold Fourier Networks for R users through the
`torch` backend.

The package supports:

- regression
- binary classification
- multiclass classification
- formula and matrix interfaces
- mini-batch training
- validation splits
- early stopping
- automatic standardization
- best-model restoration
- KAF-specific diagnostics

This vignette gives a quick tour of the main workflow.

```{r setup}
library(rkaf)

# One seed for both random number generators used below: R's own RNG and
# torch's RNG (used e.g. by torch::torch_randn()).
vignette_seed <- 123
set.seed(vignette_seed)
torch::torch_manual_seed(vignette_seed)
```

# Regression with the matrix interface

We first fit a KAF model to noise-free samples of a synthetic one-dimensional
function that mixes low-frequency and high-frequency structure.

```{r regression-data}
# 128 evenly spaced inputs on [-1, 1], stored as a one-column matrix.
x <- matrix(seq(-1, 1, length.out = 128), ncol = 1)

# Noise-free target: a fast sine term, a slower cosine term and a small
# quadratic trend.
y <- sin(8 * pi * x) + 0.35 * cos(3 * pi * x) + 0.15 * x^2
```

```{r regression-fit}
# Matrix interface. The inputs already lie in [-1, 1], so only the response
# is standardized.
fit <- kaf_fit(
  x = x,
  y = y,
  # network size and Fourier-branch settings
  hidden = c(256, 256),
  num_grids = 32,
  fourier_init_scale = 5e-2,
  use_layernorm = FALSE,
  # preprocessing
  standardize_x = FALSE,
  standardize_y = TRUE,
  # optimization
  epochs = 1000,
  lr = 1e-3,
  restore_best = TRUE,
  seed = 123,
  verbose = FALSE
)

fit
```

```{r regression-predict}
pred <- predict(fit, x)

# Target value next to the fitted value for the first few inputs.
regression_preview <- data.frame(
  observed = round(as.numeric(y), 3),
  predicted = round(pred, 3)
)
head(regression_preview)
```

```{r regression-plot, fig.width=7, fig.height=4}
# Solid line: the (noise-free) observed target; dashed line: the prediction.
fit_line_types <- c(1, 2)

plot(
  x, y,
  type = "l", lwd = 2, lty = fit_line_types[1],
  xlab = "x", ylab = "f(x)", main = "KAF regression fit"
)
lines(x, pred, lwd = 2, lty = fit_line_types[2])

legend(
  "topright",
  legend = c("Observed", "Predicted"),
  lty = fit_line_types,
  lwd = 2,
  bty = "n"
)
```

# Regression with the formula interface

For tabular data, `rkaf` also supports a formula interface.

```{r formula-regression}
# Predict mpg from weight, horsepower and number of cylinders.
fit_mtcars <- kaf_fit_formula(
  mpg ~ wt + hp + cyl,
  data = mtcars,
  # network size
  hidden = c(32, 32),
  num_grids = 16,
  # training
  epochs = 200,
  seed = 123,
  verbose = FALSE
)

fit_mtcars
```

```{r formula-predict}
mtcars_pred <- predict(fit_mtcars, mtcars)

# Observed mpg next to the fitted value for the first few cars.
mtcars_preview <- data.frame(
  observed = mtcars$mpg,
  predicted = round(mtcars_pred, 2)
)
head(mtcars_preview)
```

# Binary classification

If the response is a factor with two classes, `rkaf` automatically treats the
problem as binary classification.

```{r binary-data}
df <- mtcars

# Cars strictly above the median fuel economy are labelled "yes"; ties at the
# median fall into "no". Fixing the level order keeps "no" as the first
# level and "yes" as the second.
df$high_mpg <- factor(
  df$mpg > median(df$mpg),
  levels = c(FALSE, TRUE),
  labels = c("no", "yes")
)
```

```{r binary-fit}
# A two-level factor response makes this a binary classification fit.
fit_binary <- kaf_fit_formula(
  high_mpg ~ wt + hp + cyl,
  data = df,
  # network size
  hidden = c(32, 32),
  num_grids = 16,
  # training
  epochs = 200,
  seed = 123,
  verbose = FALSE
)

fit_binary
```

Predicted probabilities and classes:

```{r binary-prob}
prob_binary <- predict(fit_binary, df, type = "prob")
class_binary <- predict(fit_binary, df, type = "class")

# True label, predicted probability and predicted label for the first cars.
binary_preview <- data.frame(
  observed = df$high_mpg,
  prob_yes = round(prob_binary, 3),
  predicted = class_binary
)
head(binary_preview)
```

Confusion matrix:

```{r binary-confusion-matrix}
# Rows: true class; columns: predicted class.
table(df$high_mpg, class_binary, dnn = c("observed", "predicted"))
```

Raw logits:

```{r binary-link}
binary_logits <- predict(fit_binary, df, type = "link")
head(binary_logits)
```

# Multiclass classification

If the response is a factor with more than two classes, `rkaf` fits a multiclass
classifier.

```{r multiclass-fit}
# `Species ~ .` uses every other column of iris as a predictor.
fit_iris <- kaf_fit_formula(
  Species ~ .,
  data = iris,
  # network size
  hidden = c(32, 32),
  num_grids = 16,
  # training
  epochs = 300,
  seed = 123,
  verbose = FALSE
)

fit_iris
```

Confusion matrix:

```{r multiclass-confusion-matrix}
class_iris <- predict(fit_iris, iris, type = "class")

# Rows: true species; columns: predicted species.
table(iris$Species, class_iris, dnn = c("observed", "predicted"))
```

Class probabilities:

```{r multiclass-prob}
prob_iris <- predict(fit_iris, iris, type = "prob")
head(round(prob_iris, 3))
```

# Validation and early stopping

`kaf_fit()` supports validation splits, mini-batches, and early stopping. The
example below is shown but not run when the vignette is built.

```{r validation-fit, eval=FALSE}
fit_val <- kaf_fit(
  x = x,
  y = y,
  # network size
  hidden = c(64, 64),
  num_grids = 16,
  use_layernorm = FALSE,
  # mini-batches, a 20% validation split and early stopping
  batch_size = 64,
  validation_split = 0.2,
  patience = 100,
  restore_best = TRUE,
  # optimizer settings
  epochs = 300,
  lr = 5e-4,
  seed = 123,
  verbose = FALSE
)

# Inspect the training run.
plot(fit_val)
```

The fitted object stores both `train_loss_history` and
`validation_loss_history`, so users can inspect training and validation behavior
directly.

# KAF diagnostics

The KAF architecture contains a base/GELU branch and a Fourier branch. The
package exposes helper functions to inspect the learned branch scales and
Fourier parameters.

```{r diagnostics-scales}
# Learned scales of the base/GELU and Fourier branches.
scales <- extract_kaf_scales(fit)
head(scales, n = 6L)
```

```{r diagnostics-fourier}
# Fourier parameters of the first KAF layer.
fourier_params <- extract_fourier_params(fit, layer = 1)
head(fourier_params, n = 6L)
```

# Low-level torch interface

Advanced users can use the low-level torch modules directly.

```{r low-level}
# `layers` lists the layer widths: 4 inputs, two hidden layers of 16 units
# and a single output.
kaf_layers <- c(4, 16, 16, 1)
model <- nn_kaf(layers = kaf_layers, num_grids = 8)

# Forward pass on a random batch of 10 observations with 4 features each.
x_tensor <- torch::torch_randn(10, kaf_layers[1])
y_tensor <- model(x_tensor)

y_tensor$shape
```

# Summary

A typical workflow looks like this (template code, not run):

```{r workflow, eval=FALSE}
# Template (not run): replace `train_data` with your own data frame, which
# must contain a response column `y`, and `newdata` with the observations
# to predict.
fit <- kaf_fit_formula(
  y ~ .,
  data = train_data,
  hidden = c(64, 64),
  num_grids = 16,
  validation_split = 0.2,
  patience = 30
)

predict(fit, newdata)
plot(fit)
extract_kaf_scales(fit)
```

For classification, use:

```{r classification-workflow, eval=FALSE}
# Class probabilities for new observations
predict(fit, newdata, type = "prob")
# Predicted class labels for new observations
predict(fit, newdata, type = "class")
```
