Custom autograd function in PyTorch: Step-by-step explanation
Let's break down what happens in your code, especially how you get the gradient numbers, what they mean, and how PyTorch's autograd system works when you define your own function.
1. What is a custom autograd function?
- In PyTorch, you can create your own mathematical operation and tell PyTorch how to compute its gradient (how it changes with respect to its input).
- You do this by subclassing `torch.autograd.Function` and implementing two methods (a minimal sketch follows this list):
  - `forward`: computes the output from the input (the normal math)
  - `backward`: computes the gradient of the output with respect to the input (how the output changes if you nudge the input)
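Here is a minimal sketch of what that looks like, closely following the PyTorch polynomial tutorial cited at [1] (the class name `LegendrePolynomial3` comes from that tutorial; nothing about the name is required by the API):

```python
import torch

class LegendrePolynomial3(torch.autograd.Function):
    """P3(x) = 1/2 * (5x^3 - 3x), with a hand-written gradient."""

    @staticmethod
    def forward(ctx, input):
        # Save the input so backward() can use it to compute the gradient.
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: dL/dx = dL/dP3 * dP3/dx, with dP3/dx = 1.5 * (5x^2 - 1).
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)

# Call it through .apply, not by constructing the class.
P3 = LegendrePolynomial3.apply
```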
2. How do you get the gradient numbers?
- Forward pass: You calculate the output for your function. Here, it's the Legendre polynomial:
$$P_3(x) = \frac{1}{2}(5x^3 - 3x)$$
- Backward pass: You tell PyTorch the formula for the derivative of your function with respect to its input. For Legendre polynomial:
$$\frac{dP_3}{dx} = \frac{1}{2}(15x^2 - 3) = 1.5(5x^2 - 1)$$
- When you call `loss.backward()`, PyTorch uses your `backward` method to calculate how much the loss would change if you changed the input a little bit. This is the gradient.
- PyTorch automatically chains these gradients through all operations in your model, so you get the gradient of the loss with respect to every parameter (`a`, `b`, `c`, `d`). A numerical check of the hand-written derivative is sketched after this list.
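A handy way to confirm that your `backward` formula really matches your `forward` math is `torch.autograd.gradcheck`, which compares the analytical gradient against a finite-difference estimate. A small sketch, assuming the `P3` function from the snippet above:

```python
import torch

# gradcheck is sensitive to floating-point error, so use double precision.
x = torch.randn(8, dtype=torch.double, requires_grad=True)

# Returns True if the gradient from backward() agrees with a numerical
# estimate within tolerance; raises an error with details otherwise.
print(torch.autograd.gradcheck(P3, (x,)))
```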
3. What do these numbers mean?
- The gradient for each parameter (e.g., `a.grad`, `b.grad`) tells you how much the loss would change if you nudged that parameter up or down.
- If the gradient is large and positive, increasing the parameter will increase the loss (bad), so you want to decrease it.
- If the gradient is large and negative, increasing the parameter will decrease the loss (good), so you want to increase it.
- You use these numbers to update your parameters in the direction that makes the loss smaller (gradient descent).
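In code, "move each parameter against its gradient" is a one-line update per parameter. A minimal gradient-descent step, assuming `a`, `b`, `c`, `d` are the leaf tensors from your script (created with `requires_grad=True`) and the learning rate is just an illustrative value:

```python
import torch

learning_rate = 1e-6

# Update the parameters with plain gradient descent. no_grad() keeps these
# updates out of the autograd graph so they are not differentiated later.
with torch.no_grad():
    for param in (a, b, c, d):
        param -= learning_rate * param.grad
        param.grad = None  # clear the gradient before the next backward pass
```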
4. How does PyTorch get these numbers?
- During the forward pass, PyTorch builds a computation graph of all operations.
- When you call `loss.backward()`, PyTorch walks backward through this graph, using the `backward` methods you defined (and built-in ones for standard operations) and applying the chain rule to compute all gradients.
- The numbers you see in `.grad` are the result of this process: they are the derivatives of the loss with respect to each parameter. A tiny standalone example follows this list.
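You can see the graph and the chain rule at work even without a custom function, because every built-in op records its own local derivative:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2       # dy/dx = 2x
z = 3 * y + 1    # dz/dy = 3

z.backward()     # chain rule: dz/dx = dz/dy * dy/dx = 3 * 2x
print(x.grad)    # tensor(12.) at x = 2
```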
5. Step-by-step in your code
- Forward pass: you compute `y_pred = a + b * P3(c + d * x)` using your custom Legendre polynomial function.
- Loss calculation: you compute the loss as the sum of squared errors between `y_pred` and the true `y`.
- Backward pass: you call `loss.backward()`; PyTorch uses your custom `backward` method to compute the gradients for all parameters.
- Update parameters: you use the gradients to update `a`, `b`, `c`, and `d` so the loss gets smaller (a condensed training loop is sketched after this list).
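Put together, the loop looks roughly like the sketch below, in the spirit of the polynomial tutorial cited at [1] (it assumes the `P3` function defined earlier; the learning rate, iteration count, and initial values are that tutorial's defaults, nothing special):

```python
import math
import torch

# Training data: fit y = sin(x) on [-pi, pi].
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Scalar parameters a, b, c, d, tracked by autograd.
a = torch.full((), 0.0, requires_grad=True)
b = torch.full((), -1.0, requires_grad=True)
c = torch.full((), 0.0, requires_grad=True)
d = torch.full((), 0.3, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # Forward pass through the custom function.
    y_pred = a + b * P3(c + d * x)

    # Sum-of-squared-errors loss.
    loss = (y_pred - y).pow(2).sum()

    # Backward pass: fills a.grad, b.grad, c.grad, d.grad via your backward().
    loss.backward()

    # Gradient-descent update, outside the autograd graph.
    with torch.no_grad():
        for param in (a, b, c, d):
            param -= learning_rate * param.grad
            param.grad = None
```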
6. Summary
- Custom autograd functions let you define both the math and the gradient for new operations.
- Gradients are numbers that tell you how to change each parameter to make the loss smaller.
- PyTorch computes these automatically using your `backward` method and the chain rule.
- You use these gradients to update your parameters and train your model.
If you want, I can walk through a concrete example with numbers, or show how the gradients change during training! [1][2][3][4][5][6][7][8][9]
[1] https://docs.pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html
[2] https://apxml.com/courses/advanced-pytorch/chapter-1-pytorch-internals-autograd/custom-autograd-functions
[3] https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html
[4] https://brsoff.github.io/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
[5] https://ghamrouni.github.io/stn-tuto/beginner/examples_autograd/two_layer_net_custom_function.html
[6] https://docs.pytorch.org/tutorials/beginner/pytorch_with_examples.html
[7] https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html
[8] https://www.kaggle.com/code/peggy1502/learning-pytorch-2-new-autograd-functions
[9] https://stackoverflow.com/questions/54586938/how-to-wrap-pytorch-functions-and-implement-autograd