Flow chart

Image augmentation

2 global view augmentations and 6 local view augmentations, producing a list of 8 images.

import numpy as np
import torch
import skimage.io as skio
import matplotlib.pyplot as plt

def open_tif(fn):
    im = skio.imread(str(fn)) / 10000             # rescale reflectances to [0, 1]
    im = im.transpose(1, 2, 0).astype('float32')  # (bands, H, W) -> (H, W, bands)
    return im
img=open_tif('./tifs/S2B_MSIL2A_20170924T093019_45_63.tif')
plt.imshow(img[:,:,3])
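A quick check on what open_tif returns (the exact height and width depend on your tifs; a BigEarthNet-style 10 m patch would be 120 × 120 with the 10 bands used below):

print(img.shape, img.dtype)  # e.g. (120, 120, 10) float32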

BAND_STATS = {
            'S2':{
                'mean': {
                    'B01': 340.76769064,
                    'B02': 429.9430203,
                    'B03': 614.21682446,
                    'B04': 590.23569706,
                    'B05': 950.68368468,
                    'B06': 1792.46290469,
                    'B07': 2075.46795189,
                    'B08': 2218.94553375,
                    'B8A': 2266.46036911,
                    'B09': 2246.0605464,
                    'B11': 1594.42694882,
                    'B12': 1009.32729131
                },
                'std': {
                    'B01': 554.81258967,
                    'B02': 572.41639287,
                    'B03': 582.87945694,
                    'B04': 675.88746967,
                    'B05': 729.89827633,
                    'B06': 1096.01480586,
                    'B07': 1273.45393088,
                    'B08': 1365.45589904,
                    'B8A': 1356.13789355,
                    'B09': 1302.3292881,
                    'B11': 1079.19066363,
                    'B12': 818.86747235
                }
            },
            'S1': {
                'mean': {
                    'VV': -12.619993741972035,
                    'VH': -19.29044597721542,
                    'VV/VH': 0.6525036195871579,
                },
                'std': {
                    'VV': 5.115911777546365,
                    'VH': 5.464428464912864,
                    'VV/VH': 30.75264076801808,
                },
                'min': {
                    'VV': -74.33214569091797,
                    'VH': -75.11137390136719,
                    'R': 3.21E-2
                },
                'max': {
                    'VV': 34.60696029663086,
                    'VH': 33.59768295288086,
                    'R': 1.08
                }
            }
        }
bands = ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B11', 'B08', 'B8A', 'B12']
# rescale the band statistics to match the 1/10000 scaling applied in open_tif
means = [BAND_STATS['S2']['mean'][band]/10000 for band in bands]
stds = [BAND_STATS['S2']['std'][band]/10000 for band in bands]
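With max_pixel_value=1.0, the A.Normalize used below reduces to a plain per-channel (img - mean) / std; the numpy equivalent, for illustration only:

manual = (img - np.array(means, dtype='float32')) / np.array(stds, dtype='float32')
print(manual.shape)  # same (H, W, 10) image, now standardized per band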
img=open_tif('./tifs/S2B_MSIL2A_20170924T093019_45_63.tif')
plt.imshow(img[:,:,0])
from albumentations.core.transforms_interface import ImageOnlyTransform
import random

class Channelaug(ImageOnlyTransform):
    """Replace the image with a single randomly chosen band repeated over all 10 channels."""

    def __init__(self, always_apply=False, p=0.5):
        super(Channelaug, self).__init__(always_apply, p)

    def apply(self, img, **params):
        # pick one of the 10 bands at random, keep it as (H, W, 1) ...
        temp = img[:, :, random.randint(0, 9), np.newaxis]
        # ... and tile it so the output keeps the (H, W, 10) shape
        return np.repeat(temp, 10, axis=2)

import albumentations as A
from albumentations.pytorch import ToTensorV2

trainaug = []

# first global crop
globalaug1 = A.Compose([
    Channelaug(always_apply=True),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomResizedCrop(120, 120, scale=(0.4, 1.0)),
    A.GaussianBlur(p=1.0),
    # A.Solarize(threshold=0.5),
    A.Normalize(mean=means, std=stds, max_pixel_value=1.0),
    ToTensorV2(),
])

# second global crop
globalaug2 = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomResizedCrop(120, 120, scale=(0.4, 1.0)),
    A.GaussianBlur(p=0.1),
    A.Solarize(threshold=0.5, p=0.2),
    A.Normalize(mean=means, std=stds, max_pixel_value=1.0),
    ToTensorV2(),
])
# transformation for the small local crops
localaug = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomResizedCrop(56, 56, scale=(0.05, 0.4), always_apply=True),
    A.GaussianBlur(p=0.5),
    A.Normalize(mean=means, std=stds, max_pixel_value=1.0),
    ToTensorV2(),
])

trainaug.append(globalaug1)
trainaug.append(globalaug2)
for _ in range(6):
    trainaug.append(localaug)
# apply each of the 8 pipelines to the same source image
result = [aug(image=img)['image'] for aug in trainaug]
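The first two entries are the 120 × 120 global views and the remaining six are the 56 × 56 local views; after ToTensorV2 the layout is channels-first:

for t in result:
    print(t.shape)
# torch.Size([10, 120, 120]) twice, then torch.Size([10, 56, 56]) six times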
# sanity check: apply Channelaug on its own
aug = Channelaug(always_apply=True)
xx = aug(image=img)['image']
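Since Channelaug repeats one randomly chosen band across every channel, all channels of xx should be identical:

print(np.allclose(xx[:, :, 0], xx[:, :, 9]))  # True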

fig, axs = plt.subplots(2, 4, dpi=150)
for i in range(8):
    # each result is a (10, H, W) tensor after ToTensorV2, so index one band
    axs[i//4, i%4].imshow(result[i][3])

MultiCropWrapper

Perform the forward pass separately on each input resolution. Crops of the same resolution are batched together and a single backbone forward is run per resolution, so the number of backbone passes equals the number of distinct resolutions. The output features are then concatenated and the head is run once on the concatenated features.

import torch.nn as nn

class MultiCropWrapper(nn.Module):
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        # group the crops by resolution: cumulative counts of consecutive
        # same-sized crops give the slice boundaries for batching
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)
        start_idx, output = 0, torch.empty(0).to(x[0].device)
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx: end_idx]))
            # The output is a tuple with XCiT model. See:
            # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
            if isinstance(_out, tuple):
                _out = _out[0]
            # accumulate outputs
            output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        return self.head(output)
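To make the grouping concrete, here is what idx_crops evaluates to for our 2 global (120 px) and 6 local (56 px) crops:

sizes = torch.tensor([120, 120, 56, 56, 56, 56, 56, 56])
_, counts = torch.unique_consecutive(sizes, return_counts=True)
print(torch.cumsum(counts, 0))  # tensor([2, 8]): the backbone runs on x[0:2], then x[2:8]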
# inside the training loop (excerpt): the teacher sees only the two
# global views, while the student sees all eight
teacher_output = teacher(images[:2])
student_output = student(images)
loss = dino_loss(student_output, teacher_output, epoch)

# EMA update for the teacher
with torch.no_grad():
    m = momentum_schedule[it]  # momentum parameter
    for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
        param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
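Note that momentum_schedule, it, student and teacher_without_ddp come from the surrounding DINO training loop and are not defined in this excerpt. In the reference DINO code the schedule is a cosine ramp of the momentum from a base value (e.g. 0.996) up to 1.0 over all training iterations; a minimal sketch, assuming that shape:

def cosine_momentum_schedule(base_m, total_iters, final_m=1.0):
    # cosine ramp from base_m at step 0 to final_m at the last step
    iters = np.arange(total_iters)
    return final_m + 0.5 * (base_m - final_m) * (1 + np.cos(np.pi * iters / total_iters))

momentum_schedule = cosine_momentum_schedule(0.996, total_iters=10000)
print(momentum_schedule[0], momentum_schedule[-1])  # ~0.996 -> ~1.0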

The two global views are fed to the teacher, while all eight views (global and local) are fed to the student.

Loss

import torch.nn.functional as F
import torch.distributed as dist

class DINOLoss(nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we warm up the teacher temperature because too high a
        # temperature makes training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        dist.all_reduce(batch_center)  # sum the per-GPU batch centers
        batch_center = batch_center / (len(teacher_output) * dist.get_world_size())

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
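Centering (subtracting the running mean of teacher outputs) keeps the teacher from collapsing onto one dominant dimension, while the low teacher temperature sharpens the softmax. A small illustration of the sharpening effect (the logits here are made up for the demo):

logits = torch.tensor([1.0, 2.0, 3.0])
for temp in (1.0, 0.1, 0.04):
    print(temp, F.softmax(logits / temp, dim=-1))
# lower temperature -> probability mass concentrates on the largest logit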