PyTorch¶
Building on the previous auto-diff example, the switch to PyTorch and its auto-grad capabilities is trivial.
Initialization¶
The critical line that changes is loading the module:
# Load torch wrapped module.
module = spy.TorchModule.load_from_file(device, "example.slang")
Here, rather than simply write spy.Module.load_from_file, we write spy.TorchModule.load_from_file. From here, all structures or functions utilizing the module will support PyTorch tensors and be injected into PyTorch’s auto-grad graph.
In future SlangPy versions we intend to remove the need for wrapping altogether, instead auto-detecting the need for auto-grad support at the point of call.
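For context, a complete initialization might look something like the sketch below. It assumes the spy.create_device call and the example.slang file from the earlier examples; adjust the include paths to wherever your Slang file actually lives.
import pathlib
import torch
import slangpy as spy

# Create a device as in the earlier examples; the include path points at the
# folder containing example.slang.
device = spy.create_device(include_paths=[
    pathlib.Path(__file__).parent.absolute(),
])

# Load torch wrapped module.
module = spy.TorchModule.load_from_file(device, "example.slang")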
Creating a tensor¶
Now, rather than use a SlangPy Tensor, we create a torch.Tensor to store the inputs:
# Create a tensor
x = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device='cuda', requires_grad=True)
Note:
We set requires_grad=True to tell PyTorch to track the gradients of this tensor. We set device='cuda' to ensure the tensor is on the GPU.
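If you already have a CPU tensor, standard PyTorch calls can move it to the GPU and enable gradient tracking; this is plain PyTorch usage, not anything SlangPy-specific:
# Move an existing tensor to the GPU and mark it for gradient tracking.
x_cpu = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
x = x_cpu.to('cuda').requires_grad_(True)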
Running the kernel¶
Calling the function is largely unchanged; however, the calculation of gradients is now done via PyTorch:
# Evaluate the polynomial. Result will now default to a torch tensor.
# Expecting result = 2x^2 + 8x - 1
result = module.polynomial(a=2, b=8, c=-1, x=x)
print(result)
# Run backward pass on result, using result grad == 1
# to get the gradient with respect to x
result.backward(torch.ones_like(result))
print(x.grad)
This works because TorchModule automatically wraps the call to polynomial in a custom PyTorch autograd function. As a result, the call to result.backward automatically invokes module.polynomial.bwds.
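As a quick sanity check, the derivative of 2x^2 + 8x - 1 with respect to x is 4x + 8, so for the input [1, 2, 3, 4] the backward pass should leave [12, 16, 20, 24] in x.grad:
# Analytic derivative of 2x^2 + 8x - 1 is 4x + 8.
expected = 4 * torch.tensor([1, 2, 3, 4], dtype=torch.float32) + 8
print(expected)  # tensor([12., 16., 20., 24.]) -- should match x.grad above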
A word on performance¶
This example showed a very basic use of PyTorch’s auto-grad capabilities. In practice, however, switching from a CUDA PyTorch context to a D3D or Vulkan context incurs overhead. Typically, very simple logic will be faster written directly in PyTorch, but as functions become more complex, the benefit of writing them as simple scalar processes that SlangPy vectorizes and wraps for PyTorch quickly becomes apparent.
Additionally, we intend to add a pure CUDA backend to SlangPy in the future, which will allow for seamless switching between PyTorch and SlangPy contexts.
Summary¶
That’s it! You can now use PyTorch tensors with SlangPy, and take advantage of PyTorch’s auto-grad capabilities. This example covered:
Initialization with a TorchModule to enable PyTorch support.
Use of PyTorch’s .backward call to back-propagate gradients through the auto-grad graph.
Performance considerations when wrapping Slang code with PyTorch.