Just-in-Time Compilation
Just-in-time (JIT) compilation is a runtime technique where code is compiled into machine code on the fly, right before it is executed, to improve performance. Codeflash supports optimizing numerical code with JIT compilation by leveraging the JIT compilers from the Numba, PyTorch, TensorFlow, and JAX frameworks.
How Codeflash Optimizes with JIT
When Codeflash identifies a function that could benefit from JIT compilation, it:
- Rewrites the code in a JIT-compatible format, which may involve breaking down complex functions into separate JIT-compiled components.
- Generates appropriate tests that are compatible with JIT-compiled code, carefully handling data types since JIT compilers have stricter input type requirements.
- Disables JIT compilation when running coverage and the tracer. This ensures accurate coverage and trace data, since both rely on Python bytecode execution, which JIT-compiled code bypasses and would therefore prevent proper tracking.
- Disables the Line Profiler for JIT-compiled code. It would be possible to disable JIT compilation and run the line profiler, but the resulting data would be inaccurate and could misguide the optimization process.
Configuration
JIT compilation support is enabled automatically in Codeflash. You don't need to modify any configuration to enable JIT-based optimizations. Codeflash will automatically detect when JIT compilation could improve performance and suggest appropriate optimizations.
When JIT Compilation Helps
JIT compilation is most effective for:
- Numerical computations with loops that can't be easily vectorized.
- Custom algorithms not covered by existing optimized libraries.
- Functions that are called repeatedly with consistent input types.
- Code that benefits from hardware-specific optimizations (SIMD acceleration).
Example
Function Definition
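A minimal sketch of the kind of function discussed here, assuming a hypothetical fused_activation that chains roughly ten elementwise PyTorch tensor ops (the name and exact operations are illustrative):

```python
import torch

# Hypothetical example: a chain of already-vectorized elementwise tensor ops.
# Run eagerly, each op launches its own kernel.
def fused_activation(x: torch.Tensor) -> torch.Tensor:
    y = x * 2.0
    y = torch.sin(y)
    y = y + x
    y = torch.tanh(y)
    y = y * torch.sigmoid(y)
    y = torch.exp(-y * y)
    y = y / (1.0 + torch.abs(y))
    y = torch.relu(y - 0.1)
    return y.sum(dim=-1)
```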
Benchmarking Snippet (replace cuda with mps to run on your Mac)
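A sketch of a corresponding benchmark comparing the eager function against torch.compile (tensor sizes and iteration counts are illustrative):

```python
import time

import torch

device = "cuda"  # replace with "mps" to run on your Mac
x = torch.randn(4096, 4096, device=device)

compiled = torch.compile(fused_activation)
compiled(x)  # warm-up: the first call triggers compilation

def sync():
    # Wait for pending GPU work so timings are accurate.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()

for fn, label in [(fused_activation, "eager"), (compiled, "compiled")]:
    sync()
    start = time.perf_counter()
    for _ in range(100):
        fn(x)
    sync()
    print(f"{label}: {time.perf_counter() - start:.3f}s")
```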
torch.compile is the only viable option because:
- Already vectorized - All operations are already PyTorch tensor ops.
- Multiple kernel launches - Uncompiled code launches ~10 separate kernels; torch.compile fuses them into 1-2 kernels, eliminating kernel launch overhead.
- No algorithmic improvement - The computation itself is already optimal.
- Python overhead elimination - Removes Python interpreter overhead between operations.
When JIT Compilation May Not Help
JIT compilation may not provide speedups when:
- The code already uses highly optimized libraries (e.g., NumPy with MKL, cuBLAS, cuDNN).
- Functions have variable input types or shapes that prevent effective compilation.
- The compilation overhead exceeds the runtime savings for short-running functions.
- The code relies heavily on Python objects or dynamic features that JIT compilers can't optimize.
Example
Function Definition
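A minimal sketch of the kind of function discussed here, assuming a hypothetical adaptive_matmul that mixes a .item() call and data-dependent branching with a matmul (the name and exact logic are illustrative):

```python
import torch

# Hypothetical example: data-dependent control flow plus a .item() call,
# which forces host-device synchronization and a graph break.
def adaptive_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    scale = a.abs().mean().item()   # .item() syncs with the host -> graph break
    if scale > 1.0:                 # data-dependent Python branch
        a = a / scale
    else:
        a = a * 2.0
    return a @ b                    # matmul already dispatches to cuBLAS
```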
Benchmarking Snippet (replace cuda with mps to run on your Mac)
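A sketch of a corresponding benchmark for the hypothetical adaptive_matmul (sizes and iteration counts are illustrative):

```python
import time

import torch

device = "cuda"  # replace with "mps" to run on your Mac
a = torch.randn(2048, 2048, device=device)
b = torch.randn(2048, 2048, device=device)

compiled = torch.compile(adaptive_matmul)
compiled(a, b)  # warm-up: triggers compilation (and graph breaks)

def sync():
    # Wait for pending GPU work so timings are accurate.
    if device == "cuda":
        torch.cuda.synchronize()
    elif device == "mps":
        torch.mps.synchronize()

for fn, label in [(adaptive_matmul, "eager"), (compiled, "compiled")]:
    sync()
    start = time.perf_counter()
    for _ in range(100):
        fn(a, b)
    sync()
    print(f"{label}: {time.perf_counter() - start:.3f}s")
```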
torch.compile is detrimental here:
- Graph breaks - .item() forces a graph break, negating compile benefits.
- Recompilation overhead - Different branches cause expensive recompilation each time.
- Dynamic control flow - Data-dependent conditionals can't be optimized away.
- Already optimized ops - matmul already uses cuBLAS; compile adds overhead without benefit.
Better Optimization Strategy
- Eliminate .item() - Keep computation on GPU.
- Branchless execution - Compute both paths, blend results.
- Vectorization - Replace conditionals with masked operations.
- Reduce Python overhead - Minimize host-device synchronization.
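A minimal sketch applying this strategy to the hypothetical adaptive_matmul above, using torch.where to compute both paths and blend them on-device:

```python
import torch

# Hypothetical branchless rewrite: both code paths are computed on-device and
# blended with torch.where, so no .item() call or Python-level branch is needed.
def adaptive_matmul_branchless(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    scale = a.abs().mean()                             # stays on the GPU (0-dim tensor)
    scaled = torch.where(scale > 1.0, a / scale, a * 2.0)  # compute both paths, blend
    return scaled @ b
```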
Supported JIT Frameworks
Each framework uses different compilation strategies to accelerate Python code.
Numba (CPU Code)
Numba compiles Python functions to optimized machine code using the LLVM compiler infrastructure. Codeflash can suggest Numba optimizations that use:
- @jit - General-purpose JIT compilation with optional flags.
- nopython=True - Compiles to machine code without falling back to the Python interpreter.
- fastmath=True - Uses aggressive floating-point optimizations via LLVM's fastmath flag.
- cache=True - Caches the compiled function to disk, reducing compilation overhead on subsequent runs.
- parallel=True - Enables automatic parallelization of loops.
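For illustration, a sketch of a loop-heavy kernel using these flags (the function itself is hypothetical, not Codeflash output):

```python
import numpy as np
from numba import jit, prange

# Illustrative sketch: a nested-loop numerical kernel compiled with Numba.
@jit(nopython=True, fastmath=True, parallel=True)
def pairwise_distance_sum(points):
    total = 0.0
    n = points.shape[0]
    for i in prange(n):                 # parallel=True lets Numba parallelize this loop
        for j in range(i + 1, n):
            d = 0.0
            for k in range(points.shape[1]):
                diff = points[i, k] - points[j, k]
                d += diff * diff
            total += np.sqrt(d)
    return total

points = np.random.rand(500, 3)
print(pairwise_distance_sum(points))    # first call triggers compilation
```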
PyTorch
PyTorch provides JIT compilation through torch.compile(), the recommended compilation API introduced in PyTorch 2.0. It uses TorchDynamo to capture Python bytecode and TorchInductor to generate optimized kernels.
- torch.compile() - Compiles a function or module for optimized execution.
- mode - Controls the compilation strategy:
  - "default" - Balanced compilation with moderate optimization.
  - "reduce-overhead" - Minimizes Python overhead using CUDA graphs, ideal for small batches.
  - "max-autotune" - Spends more time auto-tuning to find the fastest kernels.
- fullgraph=True - Requires the entire function to be captured as a single graph. Raises an error if graph breaks occur, useful for ensuring complete optimization.
- dynamic=True - Enables dynamic shape support, allowing the compiled function to handle varying input sizes without recompilation.
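A short sketch showing how these options are combined (the model and shapes are illustrative):

```python
import torch

# Illustrative model: a small feed-forward stack.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 10),
)

# Compile the module: require a single graph and allow varying batch sizes.
compiled_model = torch.compile(model, mode="reduce-overhead",
                               fullgraph=True, dynamic=True)

out = compiled_model(torch.randn(8, 512))  # first call triggers compilation
```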
TensorFlow
TensorFlow uses @tf.function to compile Python functions into optimized TensorFlow graphs. When combined with XLA (Accelerated Linear Algebra), it can generate highly optimized machine code for both CPU and GPU.
- @tf.function - Converts Python functions into TensorFlow graphs for optimized execution.
- jit_compile=True - Enables XLA compilation, which performs whole-function optimization including operation fusion, memory layout optimization, and target-specific code generation.
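A brief sketch of these options on a hypothetical elementwise function:

```python
import tensorflow as tf

# Illustrative sketch: trace an elementwise computation into a TensorFlow graph
# and enable XLA whole-function optimization with jit_compile=True.
@tf.function(jit_compile=True)
def scaled_softplus(x, alpha):
    return alpha * tf.math.log(1.0 + tf.exp(x / alpha))

x = tf.random.normal([1024, 1024])
print(scaled_softplus(x, tf.constant(0.5)))  # first call traces and compiles
```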
JAX
JAX uses XLA to JIT compile pure functions into optimized machine code. It emphasizes functional programming patterns and captures side-effect-free operations for optimization.
- @jax.jit - JIT compiles functions using XLA with automatic operation fusion.
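A brief sketch of a pure function compiled with @jax.jit (the function itself is hypothetical):

```python
import jax
import jax.numpy as jnp

# Illustrative sketch: a side-effect-free function JIT-compiled with XLA; the
# chained elementwise ops are fused into a single optimized kernel.
@jax.jit
def normalized_tanh(x):
    y = jnp.tanh(x) * jnp.exp(-x * x)
    return y / (jnp.abs(y).max() + 1e-8)

x = jnp.linspace(-3.0, 3.0, 1_000_000)
print(normalized_tanh(x)[:5])  # first call triggers compilation and caching
```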