Paper: DDPM (Denoising Diffusion Probabilistic Models)

import torch
import torch.nn as nn
import torch.nn.functional as F

class DDPM(nn.Module):
    def __init__(self, image_channels=3, hidden_channels=64, T=1000):
        super().__init__()
        self.T = T
        # Denoising network: a simple strided encoder-decoder standing in for the
        # U-Net used in the paper. Each encoder conv halves the spatial resolution.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, hidden_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, hidden_channels * 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels * 2, hidden_channels * 4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels * 4, hidden_channels * 8, kernel_size=3, stride=2, padding=1)
        )
        # Each transposed conv doubles the spatial resolution, restoring the input
        # size (input height/width must be divisible by 16).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_channels * 8, hidden_channels * 4, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels * 4, hidden_channels * 2, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels * 2, hidden_channels, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels, image_channels, kernel_size=2, stride=2)
        )

    def forward(self, x, t):
        # Predict the noise in x at timestep t (the reverse-process network;
        # the forward diffusion itself is applied outside the model)
        embedded_t = self.time_embedding(t)
        x = self.encoder(x)
        x = x + embedded_t
        x = self.decoder(x)
        return x

    def time_embedding(self, t):
        # Minimal sinusoidal timestep embedding, shaped (B, 1, 1, 1) so it
        # broadcasts over the channel and spatial dimensions of the feature map
        return torch.sin(t.float() / self.T * torch.pi).view(-1, 1, 1, 1)
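To show how this network would actually be trained, here is a minimal training-step sketch that is not part of the original post: it diffuses a clean image with the closed-form q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I) and regresses the network output onto the sampled noise, i.e. the simplified MSE objective from the DDPM paper. The linear beta schedule, the ddpm_loss helper, and the 32x32 input size are illustrative assumptions, not code from the article.

import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear noise schedule (assumed)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # cumulative product alpha_bar_t

def ddpm_loss(model, x0):
    # Sample a timestep and Gaussian noise, form x_t via q(x_t | x_0),
    # then train the model to predict the noise that was added.
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)
    noise = torch.randn_like(x0)
    a_bar = alpha_bars.to(x0.device)[t].view(-1, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
    return F.mse_loss(model(x_t, t), noise)

# Usage with the DDPM class above (height/width divisible by 16):
# model = DDPM()
# loss = ddpm_loss(model, torch.randn(8, 3, 32, 32))
# loss.backward()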
