Differentiable tree-based models for tabular data.
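To make "differentiable tree" concrete, here is a minimal sketch of a generic soft binary decision tree in plain Julia. It is an illustration under stated assumptions, not NeuroTreeModels' actual architecture. Each internal node sends an input left with probability sigmoid(w·x + b). Each leaf is weighted by the product of the routing probabilities along its path. The prediction is the weighted sum of leaf values. Because every step is smooth, the tree can be fitted by gradient descent. The function soft_tree_predict and its arguments are hypothetical names.

```julia
# Generic soft (differentiable) binary decision tree: an illustrative sketch,
# NOT the NeuroTreeModels implementation.

sigmoid(z) = 1 / (1 + exp(-z))

# x      : input feature vector
# W      : (2^depth - 1) x nfeats split weights, one row per internal node
# b      : 2^depth - 1 split biases
# leaves : 2^depth leaf values
# Heap indexing: node i has children 2i (left) and 2i + 1 (right).
function soft_tree_predict(x::AbstractVector, W::AbstractMatrix,
                           b::AbstractVector, leaves::AbstractVector)
    nnodes = length(b)
    ispow2(nnodes + 1) || throw(ArgumentError("need 2^depth - 1 internal nodes"))
    depth = trailing_zeros(nnodes + 1)
    length(leaves) == 2^depth || throw(ArgumentError("need 2^depth leaves"))

    p_left = sigmoid.(W * x .+ b)        # probability of going left at each node
    out = 0.0
    for leaf in 0:(2^depth - 1)
        prob, node = 1.0, 1
        # The bits of `leaf` (most significant first) encode the path: 0 = left, 1 = right
        for level in (depth - 1):-1:0
            go_right = ((leaf >> level) & 1) == 1
            prob *= go_right ? (1 - p_left[node]) : p_left[node]
            node = 2node + go_right
        end
        out += prob * leaves[leaf + 1]   # leaf weighted by its path probability
    end
    return out
end
```

Since the path probabilities sum to one, all-zero weights and biases route every input 50/50 and the prediction is simply the mean of the leaf values. For example, soft_tree_predict([1.0, 2.0], zeros(3, 2), zeros(3), [1.0, 2.0, 3.0, 4.0]) gives 2.5.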
Installation, from the Julia REPL in package mode (press ] to enter it):
] add NeuroTreeModels
⚠ Compatible with Julia >= v1.10
A model configuration is defined with one of the model constructors, such as NeuroTreeRegressor:
using NeuroTreeModels, DataFrames
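config = NeuroTreeRegressor(
    loss = :mse,       # mean squared error
    nrounds = 10,      # number of training rounds
    num_trees = 16,    # number of trees in the model
    depth = 5,         # depth of each tree
    device = :cpu,     # train on the CPU
)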
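The depth and num_trees settings above control model size. As a rough worked example, assume each tree is a complete binary tree with depth split levels. This is an assumption about the architecture that this README does not state. Under it, each tree has 2^depth - 1 split nodes and 2^depth leaves. The helper tree_sizes is hypothetical.

```julia
# Hypothetical helper: size of an ensemble of complete binary trees,
# assuming `depth` counts split levels (so each tree has 2^depth leaves).
function tree_sizes(num_trees::Integer, depth::Integer)
    nodes_per_tree = 2^depth - 1   # internal (split) nodes per tree
    leaves_per_tree = 2^depth      # leaves per tree
    return (nodes = num_trees * nodes_per_tree, leaves = num_trees * leaves_per_tree)
end

tree_sizes(16, 5)  # → (nodes = 496, leaves = 512)
```

Under this assumption, each extra level of depth doubles the leaves per tree, while adding trees grows the model only linearly.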
For training on GPU, use device=:gpu in the constructor, and optionally gpuID=0 to target a specific device.
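For reference, a GPU variant of the CPU configuration above could look like the following. It uses only the keywords named in this README and assumes a supported GPU with ID 0 is available on the machine.

```julia
using NeuroTreeModels

# Same settings as the CPU configuration, but training on GPU 0
config_gpu = NeuroTreeRegressor(
    loss = :mse,
    nrounds = 10,
    num_trees = 16,
    depth = 5,
    device = :gpu,   # train on the GPU instead of the CPU
    gpuID = 0,       # optional: select which GPU to use
)
```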
Building and training a model according to the above config
is done with NeuroTreeModels.fit.
See the docs for additional features, notably early stopping, which tracks an evaluation metric on held-out evaluation data.
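The idea behind metric-based early stopping can be sketched in plain Julia. The sketch keeps the best evaluation metric seen so far and stops once it has not improved for a given number of rounds. It is a generic illustration, not NeuroTreeModels' implementation or keyword names. early_stopping_loop, train_one_round!, eval_metric and patience are all hypothetical, and lower metric values are assumed to be better.

```julia
# Generic early stopping (a sketch, not NeuroTreeModels' implementation).
# `train_one_round!` and `eval_metric` are hypothetical callbacks; lower metric is better.
function early_stopping_loop(train_one_round!, eval_metric; nrounds::Int, patience::Int)
    best_metric, best_iter = Inf, 0
    for iter in 1:nrounds
        train_one_round!(iter)           # one round of training
        metric = eval_metric()           # e.g. MSE on evaluation data
        if metric < best_metric
            best_metric, best_iter = metric, iter
        elseif iter - best_iter >= patience
            break                        # `patience` rounds without improvement
        end
    end
    return (best_iter = best_iter, best_metric = best_metric)
end

# Simulated run: the evaluation metric stops improving after round 3
metrics = [5.0, 4.0, 3.0, 3.5, 3.2, 3.1, 2.0]
last_iter = Ref(0)
result = early_stopping_loop(i -> (last_iter[] = i), () -> metrics[last_iter[]];
                             nrounds = length(metrics), patience = 2)
# result == (best_iter = 3, best_metric = 3.0); training stopped after round 5
```

The simulated run also shows the trade-off of a small patience value: the improvement at round 7 is never reached.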
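# Toy training data: 1,000 rows, 5 random features (x1 to x5) and a random target y
nobs, nfeats = 1_000, 5
dtrain = DataFrame(randn(nobs, nfeats), :auto)
dtrain.y = rand(nobs)
# Features are the columns whose names match "x"; the target is column "y"
feature_names, target_name = names(dtrain, r"x"), "y"
m = NeuroTreeModels.fit(config, dtrain; feature_names, target_name)
# Inference on the CPU
p = m(dtrain)
# Inference on the GPU (requires a GPU)
p = m(dtrain; device=:gpu)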
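Early stopping relies on evaluation data that is separate from the training rows, while the example above trains on all of dtrain. Below is a minimal sketch of a random train/evaluation split of row indices that uses only the Random standard library. train_eval_split and eval_frac are hypothetical names and are not part of NeuroTreeModels.

```julia
using Random

# Hypothetical helper: shuffle row indices and split them into train / eval sets.
function train_eval_split(nobs::Integer; eval_frac::Real = 0.2, rng = Random.default_rng())
    idx = randperm(rng, nobs)               # random permutation of 1:nobs
    neval = round(Int, eval_frac * nobs)    # number of evaluation rows
    return idx[(neval + 1):end], idx[1:neval]
end

train_idx, eval_idx = train_eval_split(1_000; eval_frac = 0.2, rng = MersenneTwister(123))
# With the DataFrame from the example, the subsets would be:
# dtrain_fit = dtrain[train_idx, :]
# deval      = dtrain[eval_idx, :]
```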
NeuroTreeModels.jl also supports the MLJ interface:
using MLJBase, NeuroTreeModels
m = NeuroTreeRegressor(depth=5, nrounds=10)
X, y = @load_boston                 # Boston housing demo dataset
mach = machine(m, X, y) |> fit!     # bind the model to the data and train it
p = predict(mach, X)                # in-sample predictions
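Predictions such as p can be scored against y with the mean squared error, the same criterion named by loss = :mse in the configuration above. Below is a plain-Julia sketch. The helper names are illustrative and belong to neither NeuroTreeModels nor MLJ. It assumes p and y are plain numeric vectors of equal length.

```julia
# Illustrative helpers (not part of NeuroTreeModels or MLJ).
function mean_squared_error(p::AbstractVector, y::AbstractVector)
    length(p) == length(y) || throw(DimensionMismatch("p and y must have equal length"))
    return sum(abs2, p .- y) / length(y)   # average of squared residuals
end

root_mean_squared_error(p, y) = sqrt(mean_squared_error(p, y))

mean_squared_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])   # → 4/3 ≈ 1.3333
```

With the MLJ example, this would read mean_squared_error(p, y). That usage assumes predict returns a plain vector of point predictions for this deterministic regressor.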
Benchmarks against prominent ML libraries for tabular data are available in MLBenchmarks.jl.