jax-best-practices

from mindrally/skills

240+ Claude Code skills converted from Cursor rules. Expert coding guidelines for every major framework and language.

0 stars · 0 forks · Updated Jan 23, 2026
npx skills add https://github.com/mindrally/skills --skill jax-best-practices

SKILL.md

JAX Best Practices

You are an expert in JAX for high-performance numerical computing and machine learning.

Core Principles

  • Follow functional programming patterns
  • Use immutability and pure functions
  • Leverage JAX transformations effectively
  • Optimize for JIT compilation

Key Transformations

jax.jit

  • Use for just-in-time compilation to optimize performance
  • Avoid side effects in jitted functions
  • Use static_argnums (or static_argnames) to mark compile-time constants; each new static value triggers a recompilation
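
As a minimal sketch of the points above, the following shows a jitted function whose second argument is marked static. The function name `repeated_scale` and its arguments are hypothetical, chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `n_steps` drives a Python-level loop, so its value must be known while
# tracing. Marking it static makes it an ordinary Python int inside the
# function, at the cost of one compilation per distinct value.
@partial(jax.jit, static_argnums=(1,))
def repeated_scale(x, n_steps):
    # Pure function: no printing, no mutation of outside state.
    for _ in range(n_steps):
        x = x * 2.0
    return x


result = repeated_scale(jnp.ones(3), 3)  # doubles each element three times
```

Calling `repeated_scale` again with the same `n_steps` and the same input shape and dtype reuses the compiled code; a different `n_steps` compiles a new version.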

jax.vmap

  • Vectorize operations over batch dimensions
  • Avoid explicit loops when possible
  • Combine with jit for best performance
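
One way to apply these points, sketched under the assumption of a hypothetical per-example function `predict`: write the function for a single input, then let `jax.vmap` add the batch dimension.

```python
import jax
import jax.numpy as jnp


def predict(weights, x):
    # Written for ONE example: weights has shape (3,), x has shape (3,).
    return jnp.tanh(jnp.dot(weights, x))


# in_axes=(None, 0): share `weights` across the batch and map over the
# leading axis of `x`. Wrapping in jit compiles the batched version.
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, 0)))

weights = jnp.array([0.5, -1.0, 2.0])
batch = jnp.ones((4, 3))  # four examples
outputs = batched_predict(weights, batch)  # shape (4,)
```

This replaces an explicit Python loop over the four rows of `batch` with a single vectorized call.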

jax.grad

  • Compute gradients of scalar-valued functions by automatic differentiation
  • Differentiates with respect to the first argument by default; use argnums to select others
  • Combine with jit for efficient gradient computation
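
A small sketch of this combination, using a hypothetical mean-squared-error `loss` and made-up data:

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Scalar-valued loss: mean squared error of a linear model.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (w) by default.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.zeros(2)
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
g = grad_fn(w, x, y)  # same shape as w
```

When both the loss value and the gradient are needed, `jax.value_and_grad` returns them together from one call.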

Best Practices

  • Write pure functions without side effects
  • Use JAX arrays (jax.numpy) instead of NumPy arrays inside transformed functions
  • Split PRNG keys with jax.random.split and never reuse a key
  • Profile and optimize hot paths
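
The key-handling and immutability points can be sketched as follows. The variable names are illustrative; the seed value 0 is arbitrary.

```python
import jax
import jax.numpy as jnp

# JAX random functions are pure: the same key always yields the same numbers.
# Split the key and give each consumer its own subkey instead of reusing one.
key = jax.random.PRNGKey(0)
key, init_key, noise_key = jax.random.split(key, 3)

weights = jax.random.normal(init_key, (3,))
noise = jax.random.normal(noise_key, (3,))  # independent of `weights`

# JAX arrays are immutable: .at[...].set(...) returns a new array and
# leaves `weights` unchanged.
updated = weights.at[0].set(0.0)
```

The first name returned by the split is kept as `key` so that later code can split it again without ever reusing `init_key` or `noise_key`.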

Performance

  • Minimize Python overhead in hot loops
  • Use appropriate dtypes (JAX defaults to 32-bit; 64-bit requires enabling jax_enable_x64)
  • Batch operations when possible
  • Profile with JAX profiler
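
As one possible illustration of moving a hot loop out of Python, the sketch below uses `jax.lax.scan` so the whole loop is compiled as a single operation. The function `running_sum` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def running_sum(xs):
    # `step` is traced once; the loop itself runs inside the compiled
    # program rather than in the Python interpreter.
    def step(carry, x):
        carry = carry + x
        return carry, carry  # (next carry, per-step output)

    init = jnp.zeros((), dtype=xs.dtype)  # keep the carry dtype consistent
    total, partials = lax.scan(step, init, xs)
    return total, partials


xs = jnp.arange(5, dtype=jnp.float32)  # explicit 32-bit dtype
total, partials = running_sum(xs)
total.block_until_ready()  # wait for async dispatch before timing
```

Because JAX dispatches work asynchronously, calling `block_until_ready()` before reading a clock avoids measuring only the dispatch time. For deeper analysis, a region like this can be wrapped in the `jax.profiler.trace` context manager.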

Common Patterns

  • Use pytrees for nested data structures
  • Implement custom vjp/jvp when needed
  • Leverage sharding for multi-device execution
  • Use gradient checkpointing (jax.checkpoint) to trade recomputation for memory
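
Three of these patterns can be sketched together in a few lines. All names here (`params`, `clip_gradient`, `block`, `loss`) are hypothetical, and the gradient-clipping rule is only one example of what a custom VJP might do. Sharding is left out because it depends on the devices available.

```python
import jax
import jax.numpy as jnp

# Pytrees: nested dicts/lists/tuples of arrays are handled as one unit.
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}
scaled = jax.tree_util.tree_map(lambda p: 0.5 * p, params)


# Custom VJP: identity on the forward pass, clipped gradient on the
# backward pass.
@jax.custom_vjp
def clip_gradient(x):
    return x


def clip_gradient_fwd(x):
    return x, None  # no residuals need to be saved


def clip_gradient_bwd(_, g):
    return (jnp.clip(g, -1.0, 1.0),)  # one cotangent per primal input


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


# Checkpointing: intermediates of `block` are recomputed during the
# backward pass instead of being stored.
@jax.checkpoint
def block(x):
    return jnp.sin(jnp.sin(x))


def loss(x):
    return jnp.sum(5.0 * clip_gradient(block(x)))


g = jax.grad(loss)(jnp.zeros(3))
```

Without `clip_gradient`, the gradient at zero would be 5 per element; the custom backward rule clips the incoming cotangent to 1 before it flows through `block`.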

Repository Stats

Stars: 0
Forks: 0
License: Apache License 2.0