Version 2.0

@Tom94 released this 08 Jul 11:27
· 17 commits to master since this release
3c7931c

tiny-cuda-nn now comes with a just-in-time (JIT) compilation mode that fuses encodings, neural networks, loss functions, and even backpropagation into single CUDA kernels. This leads to 1.5x-2.5x faster inference and training and can be enabled with a single line of code; see the "Automatic JIT" section below.

Even larger speed-ups are possible when applications integrate tightly with tiny-cuda-nn's new JIT compiler. For example, Instant NGP achieves a 5x speed-up by fusing the entire NeRF ray marcher into a single kernel. See the "Direct JIT integration" section for details on how to accomplish this.

Automatic JIT

To enable JIT compilation mode, set the jit_fusion property of your model to true. All future uses of the model, whether inference or training, will then use JIT mode. If an error occurs during JIT compilation, tiny-cuda-nn emits a warning and automatically turns JIT compilation mode off. Your code will still run, using the tiny-cuda-nn 1.X code path.

auto model = tcnn::create_from_config(...);
model->set_jit_fusion(tcnn::supports_jit_fusion()); // Enable JIT if the system supports it

Note: If your model has very large hash grids (~20 million+ parameters) or MLPs (layer sizes larger than 128 neurons), or if your GPU is an RTX 3000 series or earlier, JIT fusion can slow down training and, in rare cases, inference as well. In this case, it is recommended to enable JIT fusion separately for training and for inference and to measure whether each is faster.

JIT fusion can also be enabled via the PyTorch bindings, but the speed-up will be lower, particularly during training. This is because, in PyTorch, the JIT compiler does not have access to the whole compute graph and can therefore fuse and optimize less.

import tinycudann as tcnn

model = tcnn.NetworkWithInputEncoding(...) # Or any other tcnn model
model.jit_fusion = tcnn.supports_jit_fusion() # Enable JIT if the system supports it
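
Since JIT fusion can be slower for some models and GPUs, it helps to time inference and training separately with JIT on and off. The following is a minimal timing sketch, not part of tiny-cuda-nn: the helper names mean_seconds and compare_jit are hypothetical, and the warm-up iterations are assumed to absorb any one-time compilation cost.

```python
import time


def mean_seconds(fn, warmup=3, iters=10, sync=None):
    """Return the average wall-clock seconds per call of fn().

    sync is an optional callable (e.g. torch.cuda.synchronize) that blocks
    until queued GPU work has finished, so the timing covers the real work.
    """
    for _ in range(warmup):  # Untimed calls; assumed to cover JIT compilation.
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / iters


def compare_jit(set_jit, fn, **kwargs):
    """Time fn() with JIT fusion off, then on. Returns (off_s, on_s)."""
    set_jit(False)
    off_s = mean_seconds(fn, **kwargs)
    set_jit(True)
    on_s = mean_seconds(fn, **kwargs)
    return off_s, on_s


# Hypothetical usage with the PyTorch bindings (requires a CUDA GPU):
#   import torch
#   import tinycudann as tcnn
#   model = tcnn.NetworkWithInputEncoding(...)
#   def infer():
#       with torch.no_grad():
#           model(x)  # x: a CUDA input batch you provide
#   off_s, on_s = compare_jit(
#       lambda enabled: setattr(model, "jit_fusion", enabled),
#       infer,
#       sync=torch.cuda.synchronize,
#   )
```

Running the same comparison with a function that performs one full training step gives the training-side numbers; JIT fusion can then be left enabled only where it is actually faster.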

Direct JIT integration

tiny-cuda-nn 2.0's JIT compiler works by converting a given tiny-cuda-nn model to a CUDA device function and then compiling it into a kernel using CUDA's runtime compilation (RTC) feature.

To integrate a tiny-cuda-nn model with a larger kernel in your app, you need to

  1. turn your kernel into a string,
  2. prepend the tiny-cuda-nn model's device function,
  3. pass the result to tiny-cuda-nn's runtime compilation API.

Here is an example that implements a minimal kernel using a tiny-cuda-nn model with 32 input dimensions and 16 output dimensions:

#include <tiny-cuda-nn/rtc_kernel.h>

auto model = tcnn::create_from_config(32 /* input dims */, 16 /* output dims */, ...);
auto fused_kernel = tcnn::CudaRtcKernel(
    "your_kernel",
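    fmt::format(R"(
        {MODEL_DEVICE_FUNCTION}
        __global__ void your_kernel(...) {{
            // Get input to model from either registers or memory.
            tcnn::hvec<32> input = ...;
            // Call tiny-cuda-nn model. All 32 threads of the warp must be active here.
            // params: the model's parameters, assumed to be passed in as an argument of your_kernel.
            tcnn::hvec<16> output = model_fun(input, params);
            // Do something with the model output.
        }})",
        // Literal braces in the kernel source are doubled ({{ and }}) so that fmt::format does not treat them as placeholders.
        fmt::arg("MODEL_DEVICE_FUNCTION", model->generate_device_function("model_fun"))
    )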
);

uint32_t blocks = 1;
uint32_t threads = 128; // Must be multiple of 32 for neural networks to work.
uint32_t shmem_size = 0; // Can be any size that your_kernel needs.
cudaStream_t stream = nullptr; // Can be any stream.
fused_kernel.launch(blocks, threads, shmem_size, stream, ... /* params of your_kernel */);
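
The launch parameters above require the thread count to be a multiple of 32, and the kernel comment requires all 32 threads of a warp to be active at the model call. One way to respect both, sketched here in Python purely as arithmetic, is to round the launch up to whole blocks and let out-of-range threads process padding instead of returning before the model call. The helper name launch_dims and the padding strategy are assumptions for illustration, not part of tiny-cuda-nn.

```python
WARP_SIZE = 32  # Threads per warp on current NVIDIA GPUs.


def launch_dims(n_elements, threads_per_block=128):
    """Return (blocks, threads_per_block, padded_elements) for a 1D launch.

    padded_elements is the total number of threads launched; threads with
    index >= n_elements are expected to work on padding and discard their
    result rather than exit before the model call.
    """
    if n_elements <= 0:
        raise ValueError("n_elements must be positive")
    if threads_per_block <= 0 or threads_per_block % WARP_SIZE != 0:
        raise ValueError("threads_per_block must be a positive multiple of 32")
    # Ceiling division: enough blocks to cover every element.
    blocks = (n_elements + threads_per_block - 1) // threads_per_block
    return blocks, threads_per_block, blocks * threads_per_block


blocks, threads, padded = launch_dims(1000)
print(blocks, threads, padded)  # 8 128 1024
```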

For a complete, real-world example, refer to Instant NGP's NeRF integration with the JIT compiler.

Other additions and changes since last release

  • Added unit tests to ensure the new JIT's output matches tiny-cuda-nn with JIT disabled
  • Fixed miscellaneous bugs in the build system and in tiny-cuda-nn itself