
How can I use the GPU on the MacBook (M1)? #286

Open
yyxlh opened this issue Oct 21, 2023 · 0 comments

yyxlh commented Oct 21, 2023

PyTorch can already use the M1 GPU through the MPS backend, as the following code shows.

import torch
import math
# This checks that the current macOS version is at least 12.3.
print(torch.backends.mps.is_available())
# This checks that the current PyTorch installation was built with MPS enabled.
print(torch.backends.mps.is_built())

dtype = torch.float
device = torch.device("mps")

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x ** 2).sum()
    grad_d = (grad_y_pred * x ** 3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d

# Report the final fitted polynomial (outside the training loop, so it prints once)
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')

The output is:
True
True
99 1057.953369140625
199 732.12646484375
299 508.0294189453125
399 353.72833251953125
499 247.36788940429688
599 173.97360229492188
699 123.27410125732422
799 88.21514892578125
899 63.94678497314453
999 47.13111114501953
1099 35.46820831298828
1199 27.371381759643555
1299 21.745071411132812
1399 17.832063674926758
1499 15.10826301574707
1599 13.210683822631836
1699 11.887638092041016
1799 10.964457511901855
1899 10.319819450378418
1999 9.869365692138672
Result: y = 0.03163151070475578 + 0.8690047264099121 x + -0.005456964019685984 x^2 + -0.09507481753826141 x^3

Process finished with exit code 0
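For using the GPU from this project's own code, a minimal device-selection sketch might look like the following (assuming a PyTorch build with MPS support; the helper name pick_device and the example tensor/model lines are illustrative, not part of this project):

import torch

def pick_device() -> torch.device:
    # Prefer the Apple-silicon GPU (MPS) when both the OS and the
    # PyTorch build support it, otherwise fall back to CPU.
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
# Tensors and models must be moved onto the chosen device explicitly.
x = torch.randn(8, 3, device=device)
# model = MyModel().to(device)  # hypothetical model, shown for illustration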
