[0, 0]])
x = tf.pad(x, paddings)
_, Hp, Wp, _ = x.shape
cyclic shift
if self.shift_size > 0:
shifted_x = tf.roll(x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
else:
shifted_x = x
attn_mask = None
partition windows
x_windows = window_partition(shifted_x, self.window_size) # [nW*B, Mh, Mw, C]
x_windows = tf.reshape(x_windows, [-1, self.window_size * self.window_size, C]) # [nWB, MhMw, C]
W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask, training=training) # [nWB, MhMw, C]
merge windows
attn_windows = tf.reshape(attn_windows,
[-1, self.window_size, self.window_size, C]) # [nW*B, Mh, Mw, C]
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # [B, H’, W’, C]
reverse cyclic shift
if self.shift_size > 0:
x = tf.roll(shifted_x, shift=(self.shift_size, self.shift_size), axis=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
把前面pad的数据移除掉
x = tf.slice(x, begin=[0, 0, 0, 0], size=[B, H, W, C])
x = tf.reshape(x, [B, H * W, C])
FFN
x = shortcut + self.drop_path(x, training=training)
x = x + self.drop_path(self.mlp(self.norm2(x)), training=training)
return x
class BasicLayer(layers.Layer):
    """
    A basic Swin Transformer layer for one stage.

    The stage is a stack of ``depth`` SwinTransformerBlocks that alternate between
    non-shifted windows (shift_size=0 at even indices) and shifted windows
    (shift_size=window_size // 2 at odd indices), optionally followed by a
    downsample layer.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate.
            A list/tuple supplies one rate per block; a scalar is shared by all blocks.
            Default: 0.0
        downsample (type | None, optional): Downsample layer *class* applied at the end
            of the stage; instantiated as ``downsample(dim=dim, name="downsample")``.
            Default: None
        name (str, optional): Layer name. Default: None
    """

    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., downsample=None, name=None):
        super().__init__(name=name)
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        # Shift used by the odd-indexed (shifted-window) blocks.
        self.shift_size = window_size // 2

        # The docstring advertises tuples as well as lists of per-block rates;
        # checking only `list` would hand a whole tuple to every block.
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = [
            SwinTransformerBlock(dim=dim,
                                 num_heads=num_heads,
                                 window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else self.shift_size,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop,
                                 attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_rates else drop_path,
                                 name=f"block{i}")
            for i in range(depth)
        ]

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, name="downsample")
        else:
            self.downsample = None

    def create_mask(self, H, W):
        """
        Build the additive attention mask for shifted-window attention (SW-MSA).

        Args:
            H (int): Height of the token grid.
            W (int): Width of the token grid.

        Returns:
            A float32 tensor of shape [nW, Mh*Mw, Mh*Mw]: 0.0 where two tokens of a
            window carry the same region label, -100.0 otherwise.
        """
        # Round Hp and Wp up to integer multiples of window_size so the map
        # splits into whole windows.
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # Same channel ordering as the feature map, so window_partition can be
        # applied to it directly.
        img_mask = np.zeros([1, Hp, Wp, 1])  # [1, Hp, Wp, 1]
        # The same three bands are used along both the height and width axes.
        region_slices = (slice(0, -self.window_size),
                         slice(-self.window_size, -self.shift_size),
                         slice(-self.shift_size, None))
        cnt = 0
        for h in region_slices:
            for w in region_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        img_mask = tf.convert_to_tensor(img_mask, dtype=tf.float32)
        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = tf.reshape(mask_windows, [-1, self.window_size * self.window_size])  # [nW, Mh*Mw]
        # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] -> non-zero where the two tokens' labels differ
        attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
        # Differing regions get a large negative bias; equal regions are already 0.0.
        attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
        return attn_mask

    def call(self, x, H, W, training=None):
        """
        Run every block of the stage, then the optional downsample layer.

        Args:
            x: Token sequence, [B, L, C] per the model's call-site comments
                (presumably L == H * W — TODO confirm).
            H (int): Height of the token grid.
            W (int): Width of the token grid.
            training: Forwarded to each block.

        Returns:
            (x, H, W): the output tokens and the grid size after this stage.
            H and W are halved (rounding up) only when a downsample layer exists.
        """
        attn_mask = self.create_mask(H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            # Blocks receive the current spatial size through attributes,
            # not through call arguments.
            blk.H, blk.W = H, W
            x = blk(x, attn_mask, training=training)
        if self.downsample is not None:
            x = self.downsample(x, H, W)
            H, W = (H + 1) // 2, (W + 1) // 2
        return x, H, W
class SwinTransformer(Model):
    r"""Swin Transformer (TensorFlow / Keras implementation).

    Paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows -
    https://arxiv.org/pdf/2103.14030

    Args:
        patch_size (int): Patch size, forwarded to PatchEmbed. Default: 4
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer stage. Default: (2, 2, 6, 2)
        num_heads (tuple(int)): Number of attention heads per stage. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Maximum stochastic depth rate. Default: 0.1
        norm_layer (type): Normalization layer class; passed to PatchEmbed and
            instantiated with ``epsilon=1e-6`` for the final norm.
            Default: layers.LayerNormalization
        name (str, optional): Model name. Default: None
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, patch_size=4, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=layers.LayerNormalization, name=None, **kwargs):
        super().__init__(name=name)

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(patch_size=patch_size,
                                      embed_dim=embed_dim,
                                      norm_layer=norm_layer)
        self.pos_drop = layers.Dropout(drop_rate)

        # stochastic depth decay rule: the drop-path rate grows linearly from 0
        # to drop_path_rate across all blocks of all stages.
        dpr = np.linspace(0, drop_path_rate, sum(depths)).tolist()

        # build layers
        self.stage_layers = []
        for i_layer in range(self.num_layers):
            # Note: the stages built here differ slightly from the figure in the paper.
            # A stage here does not contain its own patch-merging layer; it ends with
            # the patch-merging layer that belongs to the next stage.
            first_block = sum(depths[:i_layer])
            last_block = sum(depths[:i_layer + 1])
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias,
                               drop=drop_rate,
                               attn_drop=attn_drop_rate,
                               drop_path=dpr[first_block:last_block],
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               name=f"layer{i_layer}")
            self.stage_layers.append(layer)

        self.norm = norm_layer(epsilon=1e-6, name="norm")
        # Single classification head (a duplicate definition with a misspelled
        # `kernel_initializer` keyword was removed).
        self.head = layers.Dense(num_classes,
                                 kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
                                 bias_initializer=initializers.Zeros(),
                                 name="head")

    # Corresponds to PyTorch's forward()
    def call(self, x, training=None):
        """
        Run the full model.

        Args:
            x: Input images, in whatever layout PatchEmbed expects
                (presumably NHWC — TODO confirm).
            training: Forwarded to dropout and to every stage.

        Returns:
            Head outputs of shape [B, num_classes] (no activation applied).
        """
        x, H, W = self.patch_embed(x)  # x: [B, L, C]
        x = self.pos_drop(x, training=training)

        for layer in self.stage_layers:
            x, H, W = layer(x, H, W, training=training)

        x = self.norm(x)  # [B, L, C]
        x = tf.reduce_mean(x, axis=1)  # average over tokens -> [B, C]
        x = self.head(x)

        return x
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
    """Build Swin-T: patch size 4, window size 7, embed_dim 96, depths (2, 2, 6, 2).

    Extra keyword arguments are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 6, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            name="swin_tiny_patch4_window7",
                            **kwargs)
    return model
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
    """Build Swin-S: patch size 4, window size 7, embed_dim 96, depths (2, 2, 18, 2).

    Extra keyword arguments are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=7,
                            embed_dim=96,
                            depths=(2, 2, 18, 2),
                            num_heads=(3, 6, 12, 24),
                            num_classes=num_classes,
                            name="swin_small_patch4_window7",
                            **kwargs)
    return model
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
    """Build Swin-B: patch size 4, window size 7, embed_dim 128, depths (2, 2, 18, 2).

    Extra keyword arguments are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window7",
                            **kwargs)
    return model
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
    """Build Swin-B: patch size 4, window size 12, embed_dim 128, depths (2, 2, 18, 2).

    Extra keyword arguments are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window12",
                            **kwargs)
    return model
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    """Build Swin-B (window size 7) with a 21841-class head by default.

    Patch size 4, embed_dim 128, depths (2, 2, 18, 2). Extra keyword arguments
    are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=7,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window7",
                            **kwargs)
    return model
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    """Build Swin-B (window size 12) with a 21841-class head by default.

    Patch size 4, embed_dim 128, depths (2, 2, 18, 2). Extra keyword arguments
    are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=12,
                            embed_dim=128,
                            depths=(2, 2, 18, 2),
                            num_heads=(4, 8, 16, 32),
                            num_classes=num_classes,
                            name="swin_base_patch4_window12",
                            **kwargs)
    return model
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
    """Build Swin-L (window size 7) with a 21841-class head by default.

    Patch size 4, embed_dim 192, depths (2, 2, 18, 2). Extra keyword arguments
    are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=7,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            name="swin_large_patch4_window7",
                            **kwargs)
    return model
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
    """Build Swin-L (window size 12) with a 21841-class head by default.

    Patch size 4, embed_dim 192, depths (2, 2, 18, 2). Extra keyword arguments
    are forwarded to SwinTransformer.
    """
    model = SwinTransformer(patch_size=4,
                            window_size=12,
                            embed_dim=192,
                            depths=(2, 2, 18, 2),
                            num_heads=(6, 12, 24, 48),
                            num_classes=num_classes,
                            name="swin_large_patch4_window12",
                            **kwargs)
    return model
下载 Swin-T 预训练权重，链接：https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
新建权重转换脚本 weight_trans.py，插入以下代码：
import torch
from model import *
def main(weights_path: str,
model_name: str,
model: tf.keras.Model):
var_dict = {v.name.split(‘:’)[0]: v for v in model.weights}
weights_dict = torch.load(weights_path, map_location=“cpu”)[“model”]
w_dict = {}
for k, v in weights_dict.items():
if “patch_embed” in k:
k = k.replace(“.”, “/”)
if “proj” in k:
k = k.replace(“proj/weight”, “proj/kernel”)
if len(v.shape) > 1:
conv weights
v = np.transpose(v.numpy(), (2, 3, 1, 0)).astype(np.float32)
w_dict[k] = v
else:
bias
w_dict[k] = v
elif “norm” in k:
k = k.replace(“weight”, “gamma”).replace(“bias”, “beta”)
w_dict[k] = v
elif “layers” in k:
k = k.replace(“layers”, “layer”)
split_k = k.split(“.”)
layer_id = split_k[0] + split_k[1]
if “block” in k:
split_k[2] = “block”
black_id = split_k[2] + split_k[3]
k = “/”.join([layer_id, black_id, *split_k[4:]])
if “attn” in k or “mlp” in k:
k = k.replace(“weight”, “kernel”)
if “kernel” in k:
v = np.transpose(v.numpy(), (1, 0)).astype(np.float32)
elif “norm” in k:
k = k.replace(“weight”, “gamma”).replace(“bias”, “beta”)
w_dict[k] = v
elif “downsample” in k:
k = “/”.join([layer_id, *split_k[2:]])
if “reduction” in k:
k = k.replace(“weight”, “kernel”)
if “kernel” in k: