JAX Quick Start

JAX is built from autograd and XLA (Accelerated Linear Algebra): autograd provides automatic differentiation, and XLA compiles the numerical code so it runs fast.
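To give a feel for the XLA side, a minimal sketch of jax.jit compiling an ordinary numerical function (the function and inputs here are made up for illustration):

import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Compiled by XLA on the first call, then the compiled version is reused
    return jnp.sum(x * 2.0)

print(scaled_sum(jnp.arange(5.0)))  # 20.0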

  • Function optimization (a simple perceptron-style linear model)
import numpy as np
import jax.numpy as jnp

def func(x, a, b):
    # Simple linear model: y = a*x + b
    y = x * a + b
    return y

def loss_function(weights, x, y):
    # Mean squared error; grad requires a scalar output, hence the mean
    a, b = weights
    y_hat = func(x, a, b)
    return jnp.mean((y_hat - y) ** 2)

What JAX adds here is the gradient: grad turns a Python function into a function that computes its derivative.

from jax import grad

def f(x):
    return x ** 2

df = grad(f)   # df computes f'(x) = 2x
df(3.0)        # returns 6.0
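By default grad differentiates with respect to the first argument, and that argument can be a Python container such as the weight list used below; a minimal sketch (the names here are just for illustration):

import jax.numpy as jnp
from jax import grad

def square_error(weights, x, y):
    a, b = weights
    return jnp.mean((a * x + b - y) ** 2)

# Gradient with respect to the first argument (the weight list);
# the result is a list [d_loss/d_a, d_loss/d_b]
grads = grad(square_error)([1.0, 0.0], jnp.array([1.0, 2.0]), jnp.array([3.0, 5.0]))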
# Random initial parameters; the targets are generated from a=3, b=4
a = np.random.random()
b = np.random.random()
weights = [a, b]
x = np.array([np.random.random() for _ in range(1000)])
y = np.array([3 * xx + 4 for xx in x])


grad_func = grad(loss_function)   # gradient of the loss w.r.t. the weight list
grad_func(weights, x, y)          # returns [d_loss/d_a, d_loss/d_b]



learning_rate = 0.001
for i in range(100):
    loss = loss_function(weights, x, y)
    da, db = grad_func(weights, x, y)
    # Gradient descent step; rebuild the weight list so the next
    # iteration uses the updated parameters
    a = a - learning_rate * da
    b = b - learning_rate * db
    weights = [a, b]
    print(i, loss)
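A common refinement (not in the original post) is to get the loss and the gradients in one call with jax.value_and_grad and to jit-compile the step; a minimal sketch, assuming the same loss_function, x, and y as above:

from jax import jit, value_and_grad

# Loss and gradients in one pass; jit compiles the whole step with XLA
loss_and_grads = jit(value_and_grad(loss_function))

weights = [a, b]
for i in range(100):
    loss, (da, db) = loss_and_grads(weights, x, y)
    weights = [weights[0] - learning_rate * da, weights[1] - learning_rate * db]
    print(i, loss)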
