pytorch-playground/PytorchTutorialCodes/5_define_autograd_function_explanation.md
2025-09-29 08:59:47 +02:00


Custom autograd function in PyTorch: Step-by-step explanation

Let's break down what happens in your code, especially how you get the gradient numbers, what they mean, and how PyTorch's autograd system works when you define your own function.


1. What is a custom autograd function?

  • In PyTorch, you can create your own mathematical operation and tell PyTorch how to compute its gradient (how it changes with respect to its input).
  • You do this by subclassing torch.autograd.Function and implementing two methods:
    • forward: computes the output from the input (normal math)
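    • backward: receives the gradient of the loss with respect to the output and returns the gradient of the loss with respect to the input, by multiplying it with the local derivative of the output with respect to the input (how the output changes if you nudge the input)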
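
As a minimal sketch of the two methods, here is what such a subclass could look like for the degree-3 Legendre polynomial used in the rest of this explanation. The class name LegendrePolynomial3 and the sample input values are assumptions borrowed from the PyTorch tutorial listed in the references, so your own code may differ in naming and details.

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    """P3(x) = 0.5 * (5x^3 - 3x) with a hand-written gradient."""

    @staticmethod
    def forward(ctx, x):
        # Save the input, because backward needs it to evaluate the derivative.
        ctx.save_for_backward(x)
        return 0.5 * (5 * x ** 3 - 3 * x)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is d(loss)/d(output). Multiplying by dP3/dx gives
        # d(loss)/d(input), which is what backward has to return.
        (x,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * x ** 2 - 1)


# Custom functions are called through .apply, not by instantiating the class.
x = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
y = LegendrePolynomial3.apply(x)

# Summing makes the output a scalar, so backward() can be called directly.
y.sum().backward()
print(x.grad)  # values: -1.5, 0.375, 6.0 (that is, 1.5 * (5x^2 - 1))
```

Note that backward never sees the loss itself. It only receives grad_output and scales it by the local derivative, which is what lets PyTorch chain it with every other operation.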

2. How do you get the gradient numbers?

  • Forward pass: You calculate the output of your function. Here, it's the degree-3 Legendre polynomial:

$$P_3(x) = \frac{1}{2}(5x^3 - 3x)$$
  • Backward pass: You tell PyTorch the formula for the derivative of your function with respect to its input. For this polynomial:

$$\frac{dP_3}{dx} = \frac{1}{2}(15x^2 - 3) = 1.5(5x^2 - 1)$$
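  • When you call loss.backward(), PyTorch calls your backward method with the gradient of the loss with respect to your function's output. Your derivative formula turns it into the gradient of the loss with respect to the function's input, that is, how much the loss would change if you changed the input a little bit.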
  • PyTorch automatically chains these gradients through all operations in your model, so you get the gradient of the loss with respect to every parameter (a, b, c, d).
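
If you want to confirm that the derivative formula in backward really matches the forward formula, one option is torch.autograd.gradcheck, which compares your analytical gradient against a finite-difference estimate. The sketch below assumes the same hypothetical LegendrePolynomial3 class as in the PyTorch tutorial, and the five test points are arbitrary. gradcheck needs double-precision inputs to be reliable.

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return 0.5 * (5 * x ** 3 - 3 * x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * x ** 2 - 1)


# Arbitrary test points in float64, as gradcheck recommends.
x = torch.linspace(-1.0, 1.0, 5, dtype=torch.float64, requires_grad=True)

# Returns True if the analytical and numerical gradients agree,
# and raises an error describing the mismatch otherwise.
ok = torch.autograd.gradcheck(LegendrePolynomial3.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)
```

A failing check usually means the formula in backward does not match the one in forward, for example a missing factor or a wrong sign.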

3. What do these numbers mean?

  • The gradient for each parameter (e.g., a.grad, b.grad, etc.) tells you how much the loss would change if you nudged that parameter up or down.
  • If the gradient is positive, increasing the parameter will increase the loss (bad), so you want to decrease it.
  • If the gradient is negative, increasing the parameter will decrease the loss (good), so you want to increase it.
  • The size of the gradient tells you how sensitive the loss is to that parameter: a large value means a small nudge changes the loss a lot.
  • You use these numbers to update your parameters in the direction that makes the loss smaller (gradient descent).
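
To make the sign rule concrete, here is a small sketch with a single parameter. The toy loss (w - 3)^2, the starting values, and the learning rate are all made up for illustration and are not taken from your code.

```python
import torch


def grad_of_loss(w_value):
    """Return d(loss)/dw for the toy loss (w - 3)^2 at the given w."""
    w = torch.tensor(w_value, requires_grad=True)
    loss = (w - 3.0) ** 2
    loss.backward()
    return w.grad.item()


learning_rate = 0.1  # hypothetical step size

# w is above the minimum at 3: the gradient is positive, so the update lowers w.
g_high = grad_of_loss(5.0)                   # 2 * (5 - 3) = 4.0
w_new_high = 5.0 - learning_rate * g_high    # 4.6

# w is below the minimum: the gradient is negative, so the update raises w.
g_low = grad_of_loss(1.0)                    # 2 * (1 - 3) = -4.0
w_new_low = 1.0 - learning_rate * g_low      # 1.4

print(g_high, w_new_high, g_low, w_new_low)
```

In both cases the update w - learning_rate * gradient moves w toward 3, where the loss is smallest.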

4. How does PyTorch get these numbers?

  • During the forward pass, PyTorch records every operation performed on tensors that require gradients in a computation graph.
  • When you call loss.backward(), PyTorch walks backward through this graph, using the backward methods you defined (and built-in ones for standard operations), applying the chain rule to compute all gradients.
  • The numbers you see in .grad are the result of this process: they are the derivatives of the loss with respect to each parameter.
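
One way to see the chain rule at work is to write the gradients out by hand and compare them with what ends up in .grad. The sketch below does this for the model y_pred = a + b * P3(c + d * x) with a sum-of-squared-errors loss. The class name, the three data points, the sin(x) targets, and the parameter values are assumptions chosen only for illustration.

```python
import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u):
        ctx.save_for_backward(u)
        return 0.5 * (5 * u ** 3 - 3 * u)

    @staticmethod
    def backward(ctx, grad_output):
        (u,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * u ** 2 - 1)


# Hypothetical data and parameter values.
x = torch.tensor([-1.0, 0.0, 2.0], dtype=torch.float64)
y = torch.sin(x)
a, b, c, d = (torch.tensor(v, dtype=torch.float64, requires_grad=True)
              for v in (0.1, -1.0, 0.2, 0.3))

u = c + d * x
y_pred = a + b * LegendrePolynomial3.apply(u)
loss = (y_pred - y).pow(2).sum()
loss.backward()  # autograd walks the graph from loss back to a, b, c, d

# The same gradients written out by hand with the chain rule.
with torch.no_grad():
    dloss_dpred = 2 * (y_pred - y)       # d(loss)/d(y_pred)
    p3 = 0.5 * (5 * u ** 3 - 3 * u)      # P3(u)
    dp3 = 1.5 * (5 * u ** 2 - 1)         # P3'(u), the formula from backward
    manual = {
        "a": dloss_dpred.sum(),
        "b": (dloss_dpred * p3).sum(),
        "c": (dloss_dpred * b * dp3).sum(),
        "d": (dloss_dpred * b * dp3 * x).sum(),
    }

for name, param in zip("abcd", (a, b, c, d)):
    print(name, param.grad.item(), manual[name].item())
```

Only the gradients for c and d pass through the custom backward, because those parameters sit inside P3. The gradients for a and b come from the built-in backward rules for addition and multiplication.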

5. Step-by-step in your code

  1. Forward pass:
    • You compute y_pred = a + b * P3(c + d * x) using your custom Legendre polynomial function.
  2. Loss calculation:
    • You compute the loss as the sum of squared errors between y_pred and the true y.
  3. Backward pass:
    • You call loss.backward(). PyTorch combines your custom backward method with the built-in ones for the addition and multiplications to compute the gradients for all parameters.
  4. Update parameters:
    • You use the gradients to update a, b, c, and d so the loss gets smaller, then reset the gradients before the next iteration, because PyTorch accumulates them in .grad.
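
Put together, the four steps could look roughly like the loop below. This is a sketch, not a copy of your code: the sin(x) target, the initial parameter values, the learning rate, and the number of iterations are assumptions modeled on the PyTorch tutorial listed in the references.

```python
import math
import torch


class LegendrePolynomial3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u):
        ctx.save_for_backward(u)
        return 0.5 * (5 * u ** 3 - 3 * u)

    @staticmethod
    def backward(ctx, grad_output):
        (u,) = ctx.saved_tensors
        return grad_output * 1.5 * (5 * u ** 2 - 1)


# Assumed training data: fit sin(x) on [-pi, pi].
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Assumed starting values for the four parameters.
a = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(-1.0, requires_grad=True)
c = torch.tensor(0.0, requires_grad=True)
d = torch.tensor(0.3, requires_grad=True)

learning_rate = 5e-6
losses = []
for step in range(200):
    # 1. Forward pass through the custom function.
    y_pred = a + b * LegendrePolynomial3.apply(c + d * x)
    # 2. Loss: sum of squared errors.
    loss = (y_pred - y).pow(2).sum()
    losses.append(loss.item())
    # 3. Backward pass fills a.grad, b.grad, c.grad, d.grad.
    loss.backward()
    # 4. Gradient descent step, done outside the graph, then reset the grads.
    with torch.no_grad():
        for p in (a, b, c, d):
            p -= learning_rate * p.grad
            p.grad = None

print(f"first loss: {losses[0]:.2f}, last loss: {losses[-1]:.2f}")
```

The update runs inside torch.no_grad() so that the parameter change itself is not recorded in the computation graph, and setting p.grad to None keeps gradients from one iteration from adding to the next.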

6. Summary

  • Custom autograd functions let you define both the math and the gradient for new operations.
  • Gradients are numbers that tell you how to change each parameter to make the loss smaller.
  • PyTorch computes these automatically using your backward method and the chain rule.
  • You use these gradients to update your parameters and train your model.

If you want, I can walk through a concrete example with numbers, or show how the gradients change during training! [1] [2] [3] [4] [5] [6] [7] [8] [9]


  1. https://docs.pytorch.org/tutorials/beginner/examples_autograd/polynomial_custom_function.html

  2. https://apxml.com/courses/advanced-pytorch/chapter-1-pytorch-internals-autograd/custom-autograd-functions

  3. https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html

  4. https://brsoff.github.io/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

  5. https://ghamrouni.github.io/stn-tuto/beginner/examples_autograd/two_layer_net_custom_function.html

  6. https://docs.pytorch.org/tutorials/beginner/pytorch_with_examples.html

  7. https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html

  8. https://www.kaggle.com/code/peggy1502/learning-pytorch-2-new-autograd-functions

  9. https://stackoverflow.com/questions/54586938/how-to-wrap-pytorch-functions-and-implement-autograd