动手学深度学习:线性回归神经网络

news/2025/2/24 7:06:38

从零实现线性回归

生成数据

import torch
def synthetic_data(w,b,num_examples):
    X=torch.normal(0,1,(num_examples,len(w)))
    y=torch.matmul(X,w)+b
    y+=torch.normal(0,0.1,y.shape)
    return X,y.reshape((-1,1))
true_w=torch.tensor([2,-3.4]).reshape(2,1)
true_b=4.2
features,label=synthetic_data(true_w,true_b,500)
print(f'features的值{features[0]},label的值{label[0]}')

在这里插入图片描述

可视化

import matplotlib.pyplot as plt
# 创建三维坐标系
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(features[:,0],features[:,1],label)

在这里插入图片描述

plt.scatter(features[:,0],label)

在这里插入图片描述

plt.scatter(features[:,1],label)

在这里插入图片描述

读取数据集

import random
def data_iter(batch_size,features,label):
    num_examples=len(features)
    indices=list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])
        yield features[batch_indices],label[batch_indices]
for x,y in data_iter(10,features,label):
    print(x,y)
    break

在这里插入图片描述

搭建模型

def linreg(x,w,b):
    return torch.matmul(x,w)+b

损失函数

def squared_loss(hat_y,y):
    return (hat_y-y.reshape(hat_y.shape))**2/2

随机梯度下降

def sgd(params,lr,batch_size):
    with torch.no_grad():
        for param in params:
            param-=lr*param.grad/batch_size
            param.grad.zero_()

训练

w=torch.normal(0,0.01,size=(2,1),requires_grad=True)
b=torch.tensor(1.0,requires_grad=True)
lr=0.03
num_epochs=5
batch_size=5
net=linreg
loss=squared_loss
for epoch in range(num_epochs):
    for x,y in data_iter(batch_size,features,label):
        l=loss(net(x,w,b),y)
        l.sum().backward()
        sgd([w,b],lr,batch_size)
    with torch.no_grad():
        train_l=loss(net(features,w,b),label)
        print(f'第{epoch+1}轮损失为{train_l.mean()}')

在这里插入图片描述

w-true_w,b-true_b

在这里插入图片描述

线性回归简明实现

读取数据

from torch.utils import data
def data_iter(data_array,batch_size,is_train=True):
    datasets=data.TensorDataset(*data_array)
    return data.DataLoader(datasets,batch_size,shuffle=is_train)

神经网络

from torch import nn
net=nn.Sequential(nn.Linear(2,1))
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

损失函数

loss=nn.MSELoss()

优化器

trainer=torch.optim.SGD(net.parameters(),lr=0.03)

训练

epochs_num=5
for epoch in range(epochs_num):
    for x,y in data_iter((features,label),batch_size):
        l=loss(net(x),y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l=loss(net(features),label)
    print(f'第{epoch+1}轮损失为{l}')

在这里插入图片描述

net[0].weight.data.reshape(true_w.shape)-true_w,net[0].bias.data-true_b

在这里插入图片描述


http://www.niftyadmin.cn/n/5864060.html

相关文章

第17篇:网络请求与Axios集成

目标:掌握在Vue3中规范地发起HTTP请求 1. 安装与基础配置 npm install axios // src/utils/request.js import axios from axios const service axios.create({ baseURL: https://api.example.com, timeout: 10000 }) export default service 2. 基础请…

正则表达式用法及其示例:匹配、查找、替换文本中的模式,及QT下如何使用正则表达式。

当然!正则表达式是一种强大的工具,用于匹配、查找、替换文本中的模式。下面是一些常见的正则表达式用法及其示例。 1、基本语法 基本元字符和语法 .:匹配任意单个字符(除了换行符)。^:匹配输入字符串的开…

从零开始玩转TensorFlow:小明的机器学习故事 4

探索深度学习 1 场景故事:小明的灵感 前不久,小明一直在用传统的机器学习方法(如线性回归、逻辑回归)来预测学校篮球比赛的胜负。虽然在朋友们看来已经很不错了,但小明发现一个问题:当比赛数据越来越多、…

【前端】react大全一本通

定期更新,建议关注收藏点赞。 内容源自本人以前的各种笔记,这里重新汇总补充一下。 目录 简介生命周期函数PWA(渐进式Web应用) 使用教程JSX(JavaScript XML)虚拟DOM 简介 React.js 是一个帮助你构建页面 U…

devops-Jenkins一键部署多台实例

Deckerfile # 第一阶段:构建阶段 FROM maven:3.8.4-openjdk-17 AS build # 设置工作目录 WORKDIR /app # 复制项目的 pom.xml 文件,先下载依赖以利用缓存 COPY pom.xml . RUN mvn dependency:go-offline # 复制项目源代码 COPY src ./src # 打包项目 RUN…

3D Gaussian Splatting(3DGS)的核心原理

3D Gaussian Splatting(3DGS)的核心原理 1. 基本概念 3D Gaussian Splatting(3DGS) 是一种基于 高斯分布的点云表示与渲染技术,核心思想是将三维场景建模为一系列 可学习的高斯分布,每个高斯分布具有以下…

使用 Promptic 进行对话管理需要具备python技术中的那些编程能力?

使用 Promptic 进行对话管理时,需要掌握一些基础的编程知识和技能,以下是详细说明: 1. Python 编程基础 Promptic 是一个基于 Python 的开发框架,因此需要具备一定的 Python 编程能力,包括: 函数定义与使用:了解如何定义函数、使用参数和返回值。类型注解:熟悉 Python…

【三十四周】文献阅读:DeepPose: 通过深度神经网络实现人类姿态估计

目录 摘要AbstractDeepPose: 通过深度神经网络实现人类姿态估计研究背景创新点方法论归一化网络结构级联细化流程 代码实践局限性实验结果总结 摘要 人体姿态估计旨在通过图像定位人体关节,是计算机视觉领域的核心问题之一。传统方法多基于局部检测与图模型&#x…