# Custom autograd function in PyTorch: Step-by-step explanation
Let's break down what happens in your code, especially how you get the gradient numbers, what they mean, and how PyTorch's autograd system works when you define your own function.
***
## 1. What is a custom autograd function?
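- In PyTorch, you can create your own mathematical operation and tell PyTorch **how to compute its gradient** (how its output changes as its input changes).
- You do this by subclassing `torch.autograd.Function` and implementing two static methods:
  - `forward(ctx, ...)`: computes the output from the input (the ordinary math). It can store tensors needed later with `ctx.save_for_backward(...)`.
  - `backward(ctx, grad_output)`: receives the gradient of the loss with respect to the output and returns the gradient of the loss with respect to the input, i.e. `grad_output` times the derivative of your function (the chain rule).
- You then use the operation through the class's `.apply` method rather than by instantiating the class.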
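Here is a minimal sketch of the two methods. It uses a hypothetical `Square` operation ($f(x) = x^2$), chosen only because its derivative $2x$ is easy to check by eye. `ctx.save_for_backward`, `ctx.saved_tensors`, `.apply`, and `torch.autograd.gradcheck` are standard PyTorch APIs. The class name and the input values are illustrative.

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical example op: f(x) = x**2."""

    @staticmethod
    def forward(ctx, x):
        # Save the input; backward needs it to evaluate f'(x) = 2x.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dLoss/dOutput; multiply by f'(x) to get dLoss/dInput.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = Square.apply(x)   # call through .apply, not Square()(x)
y.sum().backward()    # d(sum)/dy is 1 everywhere, so x.grad = 2 * x
print(x.grad)         # values 2., -4., 6.

# gradcheck compares backward against finite differences (use float64 inputs).
inp = torch.randn(4, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Square.apply, (inp,)))  # True
```

`gradcheck` is a cheap safety net whenever you hand-write a `backward`. It raises an error if your formula disagrees with a numerical derivative.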
***
## 2. How do you get the gradient numbers?
- **Forward pass:** You calculate the output of your function. Here, it is the degree-3 Legendre polynomial:
$$
P_3(x) = \frac{1}{2}(5x^3 - 3x)
$$
- **Backward pass:** You tell PyTorch the formula for the derivative of your function with respect to its input. For $P_3$:
$$
\frac{dP_3}{dx} = \frac{1}{2}(15x^2 - 3) = 1.5(5x^2 - 1)
$$
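- When you call `loss.backward()`, PyTorch passes your `backward` method the gradient of the loss with respect to $P_3$'s output. Your method multiplies it by $dP_3/dx$ to get the gradient of the loss with respect to $P_3$'s input, which measures how much the loss would change per small change in that input. This is the **gradient**.
- PyTorch chains these gradients through every operation in your model. As a result, you get the gradient of the loss with respect to every tensor created with `requires_grad=True` (here the parameters `a`, `b`, `c`, `d`).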
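The two formulas translate into a `torch.autograd.Function` almost line for line. The sketch below mirrors the `LegendrePolynomial3` class from the official PyTorch tutorial [^1]. It then checks `backward` two ways: against the hand-computed values $P_3'(0.5) = 1.5(5 \cdot 0.25 - 1) = 0.375$ and $P_3'(1) = 1.5 \cdot 4 = 6$, and against `torch.autograd.gradcheck`. The test inputs are arbitrary.

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # P3(x) = 0.5 * (5x^3 - 3x); keep x for the backward pass.
        ctx.save_for_backward(x)
        return 0.5 * (5 * x ** 3 - 3 * x)

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: dLoss/dx = dLoss/dP3 * dP3/dx, with dP3/dx = 1.5 * (5x^2 - 1).
        (x,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * x ** 2 - 1)


P3 = LegendrePolynomial3.apply

x = torch.tensor([0.5, 1.0], requires_grad=True)
P3(x).sum().backward()
print(x.grad)  # dP3/dx at 0.5 and 1.0: 0.375 and 6.0

# Finite-difference check of backward (gradcheck expects float64 inputs).
inp = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(P3, (inp,)))  # True
```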
***
## 3. What do these numbers mean?
- The gradient for each parameter (e.g., `a.grad`, `b.grad`) is the partial derivative of the loss with respect to that parameter. Roughly, it tells you how much the loss changes per unit increase in the parameter, for a small nudge.
  - If the gradient is positive, increasing the parameter increases the loss, so you want to decrease it.
  - If the gradient is negative, increasing the parameter decreases the loss, so you want to increase it.
  - The larger its magnitude, the more sensitive the loss is to that parameter.
- Gradient descent uses these numbers by moving each parameter a small step against its gradient, e.g. `a -= learning_rate * a.grad`. This lowers the loss as long as the step is small enough.
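A one-parameter sketch shows how the sign of the gradient drives the update. The parameter `w`, the target value 3, and the learning rate `lr` are made up for illustration. The loss $(w-3)^2$ at $w = 5$ has gradient $2(5-3) = 4$. Because that gradient is positive, the step decreases `w`.

```python
import torch

w = torch.tensor(5.0, requires_grad=True)  # hypothetical parameter
target = 3.0                               # hypothetical target value
lr = 0.1                                   # hypothetical learning rate

loss = (w - target) ** 2   # (5 - 3)^2 = 4
loss.backward()
print(w.grad)              # 2 * (5 - 3) = 4: positive, so decrease w

with torch.no_grad():      # the update itself must not be recorded in the graph
    w -= lr * w.grad       # 5 - 0.1 * 4 = 4.6
    w.grad = None          # .grad accumulates across backward() calls, so clear it

new_loss = (w - target) ** 2
print(loss.item(), new_loss.item())  # 4.0, then about 2.56
```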
***
## 4. How does PyTorch get these numbers?
- During the forward pass, PyTorch builds a **computation graph** of all operations.
- When you call `loss.backward()`, PyTorch walks backward through this graph, using the `backward` methods you defined (and built-in ones for standard operations), applying the chain rule to compute all gradients.
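- The numbers you see in `.grad` are the result of this process: the derivatives of the loss with respect to each parameter. Each `backward()` call adds to `.grad` rather than overwriting it, which is why training loops reset the gradients between steps.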
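For the model in this example, writing out the chain rule by hand shows exactly which numbers end up in `.grad`. Let $u_i = c + d\,x_i$, $\hat{y}_i = a + b\,P_3(u_i)$, and $L = \sum_i (\hat{y}_i - y_i)^2$. Then:

$$
\begin{aligned}
\frac{\partial L}{\partial a} &= \sum_i 2(\hat{y}_i - y_i), \\
\frac{\partial L}{\partial b} &= \sum_i 2(\hat{y}_i - y_i)\,P_3(u_i), \\
\frac{\partial L}{\partial c} &= \sum_i 2(\hat{y}_i - y_i)\,b\,P_3'(u_i), \\
\frac{\partial L}{\partial d} &= \sum_i 2(\hat{y}_i - y_i)\,b\,P_3'(u_i)\,x_i,
\end{aligned}
$$

Here $P_3'(u) = 1.5(5u^2 - 1)$ is the factor the custom `backward` contributes. The sketch below checks these formulas against what autograd puts in `.grad`. The sample points, the target $\sin(x)$, and the parameter values are arbitrary choices for illustration.

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return 0.5 * (5 * x ** 3 - 3 * x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * x ** 2 - 1)


P3 = LegendrePolynomial3.apply
dt = torch.double  # float64 keeps the comparison tight

# Arbitrary sample data and parameter values, chosen only for illustration.
x = torch.linspace(-1.0, 1.0, 7, dtype=dt)
y = torch.sin(x)
a, b, c, d = (torch.tensor(v, dtype=dt, requires_grad=True) for v in (0.1, -0.5, 0.2, 0.8))

# Forward pass builds the graph; backward fills a.grad, b.grad, c.grad, d.grad.
u = c + d * x
y_pred = a + b * P3(u)
loss = (y_pred - y).pow(2).sum()
loss.backward()

# The same gradients written out by hand with the chain rule.
with torch.no_grad():
    r = 2 * (y_pred - y)              # dL/dy_pred
    dp3 = 1.5 * (5 * u ** 2 - 1)      # P3'(u), the factor the custom backward supplies
    manual = {
        "a": r.sum(),
        "b": (r * P3(u)).sum(),
        "c": (r * b * dp3).sum(),
        "d": (r * b * dp3 * x).sum(),
    }

for name, p in zip("abcd", (a, b, c, d)):
    print(name, p.grad.item(), manual[name].item())  # the two columns should match
```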
***
## 5. Step-by-step in your code
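1. **Forward pass:**
   - You compute `y_pred = a + b * P3(c + d * x)`, where `P3` is your custom Legendre polynomial function.
2. **Loss calculation:**
   - You compute the loss as the sum of squared errors between `y_pred` and the true `y`.
3. **Backward pass:**
   - You call `loss.backward()`. PyTorch runs your custom `backward` method for the `P3` step and its built-in backward rules for the additions, multiplications, and squared error. Together, these fill in `a.grad`, `b.grad`, `c.grad`, and `d.grad`.
4. **Update parameters:**
   - You move `a`, `b`, `c`, and `d` a small step against their gradients so the loss gets smaller. The update runs inside `torch.no_grad()` so it is not recorded in the graph, and the gradients are reset before the next iteration.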
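Putting the four steps together gives a training loop like the one in the official tutorial [^1], which fits $y = \sin(x)$ on $[-\pi, \pi]$. The data range, initial values, learning rate, and iteration count below follow that tutorial. Treat them as a starting point rather than tuned settings.

```python
import math

import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return 0.5 * (5 * x ** 3 - 3 * x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * x ** 2 - 1)


P3 = LegendrePolynomial3.apply

x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Initial values from the tutorial, which notes they must start reasonably
# close to a good fit for this example to converge.
a = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(-1.0, requires_grad=True)
c = torch.tensor(0.0, requires_grad=True)
d = torch.tensor(0.3, requires_grad=True)

learning_rate = 5e-6
losses = []
for t in range(2000):
    # 1. Forward pass through the custom function.
    y_pred = a + b * P3(c + d * x)
    # 2. Loss: sum of squared errors.
    loss = (y_pred - y).pow(2).sum()
    losses.append(loss.item())
    # 3. Backward pass: fills a.grad, b.grad, c.grad, d.grad.
    loss.backward()
    # 4. Gradient-descent update outside the graph, then reset the gradients.
    with torch.no_grad():
        for p in (a, b, c, d):
            p -= learning_rate * p.grad  # in-place, so a, b, c, d are updated
            p.grad = None

print(f"loss: {losses[0]:.1f} -> {losses[-1]:.1f}")
print(f"y = {a.item():.3f} + {b.item():.3f} * P3({c.item():.3f} + {d.item():.3f} x)")
```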
***
## 6. Summary
- **Custom autograd functions** let you define both the math and the gradient for new operations.
- **Gradients** are numbers that tell you how to change each parameter to make the loss smaller.
- **PyTorch computes these automatically** using your `backward` method and the chain rule.
- You use these gradients to update your parameters and train your model.
If you want, I can walk through a concrete example with numbers, or show how the gradients change during training!
[^1][^2][^3][^4][^5][^6][^7][^8][^9]
[^1]: https://docs.pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html
[^2]: https://apxml.com/courses/advanced-pytorch/chapter-1-pytorch-internals-autograd/custom-autograd-functions
[^3]: https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html
[^4]: https://brsoff.github.io/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
[^5]: https://ghamrouni.github.io/stn-tuto/beginner/examples_autograd/two_layer_net_custom_function.html
[^6]: https://docs.pytorch.org/tutorials/beginner/pytorch_with_examples.html
[^7]: https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html
[^8]: https://www.kaggle.com/code/peggy1502/learning-pytorch-2-new-autograd-functions
[^9]: https://stackoverflow.com/questions/54586938/how-to-wrap-pytorch-functions-and-implement-autograd