"""
Compute the gradients of a linear classifier's cross-entropy loss on MNIST with PyTorch.
"""
from torch.nn import CrossEntropyLoss, Linear
from utils import load_mnist_data
# NOTE(review): presumably X is a float tensor whose last dimension is 784
# (flattened 28x28 images) and y holds integer class labels in [0, 10);
# confirm against utils.load_mnist_data.
X, y = load_mnist_data()

# Single affine layer mapping 784 input features to 10 output logits.
model = Linear(784, 10)

# CrossEntropyLoss applies log-softmax to the logits internally; with its
# default reduction it averages the per-sample losses into one scalar.
lossfunc = CrossEntropyLoss()
loss = lossfunc(model(X), y)

# Backpropagate through the model so every parameter's .grad is populated.
loss.backward()

# One gradient tensor per parameter (weight, then bias), each shaped like
# the parameter it belongs to.
for param in model.parameters():
    print(param.grad)