import torch
import torch.nn as nn
import torch.nn.functional as F
class DDPM(nn.Module):
def __init__(self, image_channels=3, hidden_channels=64, T=1000):
super().__init__()
self.T = T
# 定義擴(kuò)散模型的網(wǎng)絡(luò)結(jié)構(gòu)
self.encoder = nn.Sequential(
nn.Conv2d(image_channels, hidden_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_channels, hidden_channels * 2, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_channels * 2, hidden_channels * 4, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_channels * 4, hidden_channels * 8, kernel_size=3, padding=1)
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(hidden_channels * 8, hidden_channels * 4, kernel_size=2, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(hidden_channels * 4, hidden_channels * 2, kernel_size=2, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(hidden_channels * 2, hidden_channels, kernel_size=2, stride=2),
nn.ReLU(),
nn.ConvTranspose2d(hidden_channels, image_channels, kernel_size=2, stride=2)
)
def forward(self, x, t):
# 前向擴(kuò)散過(guò)程
embedded_t = self.time_embedding(t)
x = self.encoder(x)
x = x + embedded_t
x = self.decoder(x)
return x
def time_embedding(self, t):
# 時(shí)間嵌入
return torch.sin(t / self.T * torch.pi).unsqueeze(-1).unsqueeze(-1)
paper: DDPM
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
推薦閱讀更多精彩內(nèi)容
- 時(shí)間過(guò)去,論文還相當(dāng)于爛尾,還有很多篇想寫的,不是漂浮,要落地,要靜下來(lái),要沉下心,要專注,要深入,不要任何的表面...
- 2009-06-06 00:20 情感這個(gè)東西真的很怪,像生了根,拔不掉。今晚山那邊閃電撕裂烏云,冷雨中一排學(xué)生雷...
- Paper reading 1.Title Deep Snake for Real-Time Instance S...