# Custom autograd function in PyTorch: Step-by-step explanation
Let's break down what happens in your code, especially how you get the gradient numbers, what they mean, and how PyTorch's autograd system works when you define your own function.
***
## 1. What is a custom autograd function?
- In PyTorch, you can create your own mathematical operation and tell PyTorch **how to compute its gradient** (how it changes with respect to its input).
- You do this by subclassing `torch.autograd.Function` and implementing two static methods:
- `forward`: computes the output from the input (normal math)
- `backward`: given the incoming gradient of the loss with respect to the output, computes the gradient with respect to the input by multiplying by the local derivative (how the output changes if you nudge the input); a minimal sketch follows this list
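To make this concrete, here is a minimal sketch of such a subclass for the Legendre polynomial $P_3$ discussed below, following the pattern of the official PyTorch tutorial[^1] (the class and alias names are just illustrative):

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    """Implements P3(x) = 0.5 * (5x^3 - 3x) with a hand-written gradient."""

    @staticmethod
    def forward(ctx, input):
        # Stash the input; backward needs it to evaluate the derivative.
        ctx.save_for_backward(input)
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is dLoss/dOutput; multiply by the local derivative
        # dP3/dx = 1.5 * (5x^2 - 1) to get dLoss/dInput (chain rule).
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)


# Call it like a regular function via .apply.
P3 = LegendrePolynomial3.apply
```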
***
## 2. How do you get the gradient numbers?
- **Forward pass:** You calculate the output for your function. Here, it's the Legendre polynomial:
$$
P_3(x) = \frac{1}{2}(5x^3 - 3x)
$$
- **Backward pass:** You tell PyTorch the formula for the derivative of your function with respect to its input. For the Legendre polynomial above:
$$
\frac{dP_3}{dx} = \frac{1}{2}(15x^2 - 3) = 1.5(5x^2 - 1)
$$
- When you call `loss.backward()`, PyTorch uses your `backward` method to calculate how much the loss would change if you changed the input a little bit. This is the **gradient**.
- PyTorch automatically chains these gradients through all operations in your model, so you get the gradient of the loss with respect to every parameter (a, b, c, d); a small numeric check follows this list.
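A quick numeric check makes these formulas tangible (a sketch assuming the `P3` alias from the section 1 sketch; the variable names are illustrative):

```python
import torch

# One input point where the numbers are easy to verify by hand.
x = torch.tensor(2.0, requires_grad=True)

y = P3(x)          # forward: 0.5 * (5*2**3 - 3*2) = 17.0
loss = y ** 2      # a toy loss built on top of the custom op

loss.backward()    # chain rule: dloss/dx = 2*y * dP3/dx

# dP3/dx at x=2 is 1.5 * (5*4 - 1) = 28.5, so dloss/dx = 2 * 17.0 * 28.5
print(x.grad)      # tensor(969.)
```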
***
## 3. What do these numbers mean?
- The gradient for each parameter (e.g., `a.grad`, `b.grad`, etc.) tells you **how much the loss would change if you nudged that parameter up or down**.
- If the gradient is large and positive, increasing the parameter will increase the loss (bad), so you want to decrease it.
- If the gradient is large and negative, increasing the parameter will decrease the loss (good), so you want to increase it.
- You use these numbers to update your parameters in the direction that makes the loss smaller (gradient descent); a minimal update step is sketched after this list.
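As a tiny illustration of those signs (a hypothetical one-parameter toy, not part of your script):

```python
import torch

# loss(w) = (w - 3)^2 is minimised at w = 3.
w = torch.tensor(0.0, requires_grad=True)

loss = (w - 3) ** 2
loss.backward()
print(w.grad)            # tensor(-6.): negative, so increasing w lowers the loss

with torch.no_grad():
    w -= 0.1 * w.grad    # gradient descent step: w moves toward 3
    w.grad = None        # clear the gradient before the next iteration
print(w)                 # tensor(0.6000, requires_grad=True)
```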
***
## 4. How does PyTorch get these numbers?
- During the forward pass, PyTorch builds a **computation graph** of all operations.
- When you call `loss.backward()`, PyTorch walks backward through this graph, using the `backward` methods you defined (and built-in ones for standard operations), applying the chain rule to compute all gradients.
- The numbers you see in `.grad` are the result of this process: they are the derivatives of the loss with respect to each parameter (the sketch below shows how to inspect the graph).
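You can peek at this graph yourself through the `grad_fn` attribute. A small sketch, assuming the `P3` alias from the section 1 sketch, with `a`, `b`, `c`, `d` standing in for your parameters:

```python
import torch

# Tiny stand-ins for the parameters in your script.
a, b, c, d = (torch.randn((), requires_grad=True) for _ in range(4))
x = torch.linspace(-1.0, 1.0, 5)

y_pred = a + b * P3(c + d * x)   # the custom op sits inside an ordinary expression
loss = (y_pred - torch.sin(x)).pow(2).sum()

# Every intermediate tensor records the operation that produced it; following
# .grad_fn and .next_functions traces the graph that backward() walks in reverse.
print(loss.grad_fn)                 # <SumBackward0 ...>
print(y_pred.grad_fn)               # <AddBackward0 ...>
print(loss.grad_fn.next_functions)  # links back toward the custom P3 node
```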
***
## 5. Step-by-step in your code
1. **Forward pass:**
- You compute `y_pred = a + b * P3(c + d * x)` using your custom Legendre polynomial function.
2. **Loss calculation:**
- You compute the loss as the sum of squared errors between `y_pred` and the true `y`.
3. **Backward pass:**
- You call `loss.backward()`. PyTorch uses your custom `backward` method to compute the gradients for all parameters.
4. **Update parameters:**
    - You use the gradients to update `a`, `b`, `c`, and `d` so the loss gets smaller (the full loop is sketched below).
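Put together, these four steps form the whole training loop. Here is a condensed sketch in the spirit of the PyTorch tutorial[^1], again assuming the `P3` alias from section 1:

```python
import math
import torch

# Data: fit y = sin(x) on [-pi, pi] with a + b * P3(c + d * x).
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# The four learnable parameters.
a = torch.full((), 0.0, requires_grad=True)
b = torch.full((), -1.0, requires_grad=True)
c = torch.full((), 0.0, requires_grad=True)
d = torch.full((), 0.3, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # 1. Forward pass through the custom Function.
    y_pred = a + b * P3(c + d * x)

    # 2. Loss: sum of squared errors.
    loss = (y_pred - y).pow(2).sum()

    # 3. Backward pass: fills a.grad, b.grad, c.grad, d.grad
    #    via the custom backward and the chain rule.
    loss.backward()

    # 4. Gradient-descent update, then reset the gradients.
    with torch.no_grad():
        for p in (a, b, c, d):
            p -= learning_rate * p.grad
            p.grad = None

    if t % 500 == 0:
        print(t, loss.item())
```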
***
## 6. Summary
- **Custom autograd functions** let you define both the math and the gradient for new operations.
- **Gradients** are numbers that tell you how to change each parameter to make the loss smaller.
- **PyTorch computes these automatically** using your `backward` method and the chain rule.
- You use these gradients to update your parameters and train your model.
If you want, I can walk through a concrete example with numbers, or show how the gradients change during training!
[^1][^2][^3][^4][^5][^6][^7][^8][^9]

***
[^1]: https://docs.pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html

[^2]: https://apxml.com/courses/advanced-pytorch/chapter-1-pytorch-internals-autograd/custom-autograd-functions

[^3]: https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html

[^4]: https://brsoff.github.io/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

[^5]: https://ghamrouni.github.io/stn-tuto/beginner/examples_autograd/two_layer_net_custom_function.html

[^6]: https://docs.pytorch.org/tutorials/beginner/pytorch_with_examples.html

[^7]: https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html

[^8]: https://www.kaggle.com/code/peggy1502/learning-pytorch-2-new-autograd-functions

[^9]: https://stackoverflow.com/questions/54586938/how-to-wrap-pytorch-functions-and-implement-autograd