Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ dist/
build/
images/
outputs/
output/
DATASETS/
multirun/
exp/
handling/
tests/
scripts/
wandb/
*.code-workspace

configs/train/*
Expand Down
53 changes: 53 additions & 0 deletions animaloc/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,59 @@ def __call__(

return image, target

@TRANSFORMS.register()
class AnimalDensity:
''' Compute animal density per tile '''

def __init__(
self,
max_animals: float = 100.0,
anno_type: str = 'binary'
) -> None:
'''
Args:
anno_type (str, optional): choose between 'binary' for bounding box or 'density'
for points. Defaults to 'binary'
'''

assert anno_type in ['binary', 'density'], \
f'Annotations type must be \'binary\' or \'density\', got \'{anno_type}\''
self.max_animals= max_animals
self.anno_type = anno_type

def __call__(
self,
image: Union[PIL.Image.Image, torch.Tensor],
target: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
'''
Args:.
image (PIL.Image.Image or torch.Tensor): image of reference [C,H,W], only for
pipeline convenience, original size is kept
target (dict): target containing at least 'boxes' (or 'points') and 'labels'
keys, with torch.Tensor as value. Labels must be integers!

Returns:
Dict[str, torch.Tensor]
the down-sampled target
'''

if isinstance(image, PIL.Image.Image):
image = torchvision.transforms.ToTensor()(image)

if self.anno_type == 'binary': # binary case (empty or not empty image)
if len(target['labels']): # we have annotations = not empty
target['labels'] = torch.as_tensor([1], dtype=torch.int64)
else: # we don't have annotations = empty
target['labels'] = torch.as_tensor([0], dtype=torch.int64)
elif self.anno_type == 'density': #TODO: complete for density case
if len(target['labels']): # we have annotations = not empty
target['labels'] = torch.as_tensor(len(target['labels'])/self.max_animals, dtype=torch.int64)
else: # we don't have annotations = empty
target['labels'] = torch.as_tensor([0], dtype=torch.int64)

return image, target['labels'].float()

@TRANSFORMS.register()
class PointsToMask:
''' Convert points annotations to mask with a buffer option '''
Expand Down
8 changes: 4 additions & 4 deletions animaloc/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def __init__(
self.folder_images = [i for i in os.listdir(self.root_dir)
if i.endswith(('.JPG','.jpg','.JPEG','.jpeg'))]

self._img_names = self.folder_images
self._img_names = self.folder_images
self.anno_keys = self.data.columns
self.data['from_folder'] = 0

self.data['from_folder'] = 0 # all images in the folder
folder_only_images = numpy.setdiff1d(self.folder_images, self.data['images'].unique().tolist())
folder_df = pandas.DataFrame(data=dict(images = folder_only_images))
folder_df['from_folder'] = 1
folder_df['from_folder'] = 1 # some have annotations

self.data = pandas.concat([self.data, folder_df], ignore_index=True).convert_dtypes()

Expand Down
70 changes: 57 additions & 13 deletions animaloc/eval/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy
import wandb
import matplotlib
from itertools import chain

matplotlib.use('Agg')

Expand Down Expand Up @@ -198,11 +199,10 @@ def evaluate(self, returns: str = 'recall', wandb_flag: bool = False, viz: bool
if i % self.print_freq == 0 or i == len(self.dataloader) - 1:
fig = self._vizual(image = images, target = targets, output = output)
wandb.log({'validation_vizuals': fig})

output = self.prepare_feeding(targets, output)

iter_metrics.feed(**output)
iter_metrics.aggregate()
for b in range(images.shape[0]):
batch_output = self.prepare_feeding(dict(labels= targets['labels'][b], points= targets['points'][b]), (output[0][b].unsqueeze(0), output[1][b].unsqueeze(0)))
iter_metrics.feed(**batch_output)
iter_metrics.aggregate()
if log_meters:
logger.add_meter('n', sum(iter_metrics.tp) + sum(iter_metrics.fn))
logger.add_meter('recall', round(iter_metrics.recall(),2))
Expand All @@ -224,8 +224,10 @@ def evaluate(self, returns: str = 'recall', wandb_flag: bool = False, viz: bool
})

iter_metrics.flush()

self.metrics.feed(**output)
for b in range(images.shape[0]):
batch_output = self.prepare_feeding(dict(labels= targets['labels'][b], points= targets['points'][b]), (output[0][b].unsqueeze(0), output[1][b].unsqueeze(0)))
self.metrics.feed(**batch_output)
#self.metrics.feed(**output)

self._stored_metrics = self.metrics.copy()

Expand Down Expand Up @@ -345,14 +347,16 @@ def post_stitcher(self, output: torch.Tensor) -> Any:

def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: List[torch.Tensor]) -> dict:

gt_coords = [p[::-1] for p in targets['points'].squeeze(0).tolist()]
gt_labels = targets['labels'].squeeze(0).tolist()

gt_coords = [p[::-1] for p in targets['points'].tolist()]
gt_labels = targets['labels'].tolist()

ndim= numpy.array(gt_coords).ndim
gt = dict(
loc = gt_coords,
labels = gt_labels
)


up = True
if self.stitcher is not None:
up = False
Expand All @@ -363,8 +367,8 @@ def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: List[torch.T
preds = dict(
loc = locs[0],
labels = labels[0],
scores = scores[0],
dscores = dscores[0]
scores = scores[0], # class scores
dscores = dscores[0] # heatmap scores
)

return dict(gt = gt, preds = preds, est_count = counts[0])
Expand All @@ -390,6 +394,27 @@ def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: torch.Tensor

return dict(gt = gt, preds = preds, est_count = est_counts)

@EVALUATORS.register()
class DensityMapEvaluator(Evaluator):

def prepare_data(self, images: Any, targets: Any) -> tuple:
return images.to(self.device), targets

def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: torch.Tensor) -> dict:

gt_coords = [p[::-1] for p in targets['points'].squeeze(0).tolist()]
gt_labels = targets['labels'].squeeze(0).tolist()

gt = dict(loc = gt_coords, labels = gt_labels)
preds = dict(loc = [], labels = [], scores = [])

_, idx = torch.max(output, dim=1)
masks = F.one_hot(idx, num_classes=output.shape[1]).permute(0,3,1,2)
output = (output * masks)
est_counts = output[0].sum(2).sum(1).tolist()

return dict(gt = gt, preds = preds, est_count = est_counts)

@EVALUATORS.register()
class FasterRCNNEvaluator(Evaluator):

Expand Down Expand Up @@ -420,4 +445,23 @@ def prepare_feeding(self, targets: List[dict], output: List[dict]) -> dict:
num_classes = self.metrics.num_classes - 1
counts = [preds['labels'].count(i+1) for i in range(num_classes)]

return dict(gt = gt, preds = preds, est_count = counts)
return dict(gt = gt, preds = preds, est_count = counts)

@EVALUATORS.register()
class TileEvaluator(Evaluator):

def prepare_data(self, images: Any, targets: Any) -> tuple:
return images.to(self.device), targets

def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: torch.Tensor) -> dict:


gt_labels = list(chain.from_iterable(targets[0].tolist()))
gt_labels = [int(l+1) for l in gt_labels]
gt = dict(loc = [], labels = gt_labels)
preds = dict(loc = [], labels = [], scores = [])

scores= list(chain.from_iterable(output.tolist()))
labels= [2 if s>0 else 1 for s in scores]
preds = dict(loc = [], labels = labels, scores = scores)
return dict(gt = gt_labels, preds = labels)
43 changes: 34 additions & 9 deletions animaloc/eval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,27 +636,52 @@ def __init__(self, num_classes: int = 2) -> None:
num_classes = num_classes + 1 # for convenience
super().__init__(0, num_classes)

def feed(self, gt: int, pred: int) -> tuple:
def feed(self, gt: int, preds: int) -> tuple:
'''
Args:
gt (int): numeric ground truth label
pred (int): numeric predicted label
'''

gt = dict(labels=[gt], loc=[(0,0)])
preds = dict(labels=[pred], loc=[(0,0)])
preds = dict(labels=[preds], loc=[(0,0)])

super().feed(gt, preds)

def matching(self, gt: dict, pred: dict) -> None:
gt_lab = gt['labels'][0]
p_lab = pred['labels'][0]
for g, p in zip(gt_lab, p_lab): #TODO: To be confirmed
if g == p:
self.tp[g-1] += 1
else:
self.fp[p-1] += 1
self.fn[g-1] += 1

self._confusion_matrix += confusion_matrix(gt_lab, p_lab, labels=list(range(1, self.num_classes)))

@METRICS.register()
class RegressionMetrics(Metrics):
''' Metrics class for regression type tasks '''

if gt_lab == p_lab:
self.tp[gt_lab-1] += 1
else:
self.fp[p_lab-1] += 1
self.fn[gt_lab-1] += 1
def __init__(self, num_classes: int = 2) -> None:
num_classes = num_classes + 1 # for convenience
super().__init__(0, num_classes)

def feed(self, gt: float, pred: float) -> tuple:
'''
Args:
gt (float): numeric ground truth value
pred (float): numeric predicted value
'''

gt = dict(labels=[gt], loc=[(0,0)])
preds = dict(labels=[pred], loc=[(0,0)])

self._confusion_matrix += confusion_matrix(
[gt_lab], [p_lab], labels=list(range(self.num_classes-1)))
super().feed(gt, preds)

def matching(self, gt: dict, pred: dict) -> None:
gt_lab = gt['labels'][0]
p_lab = pred['labels'][0]

diff= math.abs(gt_lab-p_lab) # L1-loss
1 change: 1 addition & 0 deletions animaloc/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
from .herdnet import *
from .utils import *
from .ss_dla import *
from .dla_backbone import *

__all__ = ['MODELS', *MODELS.registry_names]
93 changes: 93 additions & 0 deletions animaloc/models/dla_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
__copyright__ = \
"""
Copyright (C) 2022 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
All rights reserved.

This source code is under the CC BY-NC-SA-4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/).
It is to be used for academic research purposes only, no commercial use is permitted.

Please contact the author Alexandre Delplanque ([email protected]) for any questions.

Last modification: March 29, 2023
"""
__author__ = "Alexandre Delplanque"
__license__ = "CC BY-NC-SA 4.0"
__version__ = "0.2.0"


import torch

import torch.nn as nn
import numpy as np
import torchvision.transforms as T

from typing import Optional

from .register import MODELS

from . import dla as dla_modules


@MODELS.register()
class DLAEncoder(nn.Module):
''' DLA encoder architecture '''

def __init__(
self,
num_layers: int = 34,
num_classes: int = 2,
pretrained: bool = True,
):
'''
Args:
num_layers (int, optional): number of layers of DLA. Defaults to 34.
num_classes (int, optional): number of output classes, background included.
Defaults to 2.
pretrained (bool, optional): set False to disable pretrained DLA encoder parameters
from ImageNet. Defaults to True.
'''

super(DLAEncoder, self).__init__()

base_name = 'dla{}'.format(num_layers)

self.num_classes = num_classes

# backbone
base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True)
setattr(self, 'base_0', base)
setattr(self, 'channels_0', base.channels)

channels = self.channels_0


# bottleneck conv
self.bottleneck_conv = nn.Conv2d(
channels[-1], channels[-1],
kernel_size=1, stride=1,
padding=0, bias=True
)
self.pooling= nn.AvgPool2d(kernel_size= 16, stride=1, padding=0) # we take the average of each filter
self.cls_head = nn.Linear(512, 1) # binary head

def forward(self, input: torch.Tensor):

encode = self.base_0(input) # Nx512x16x16
bottleneck = self.bottleneck_conv(encode[-1])
bottleneck = self.pooling(bottleneck)
bottleneck= torch.reshape(bottleneck, (bottleneck.size()[0],-1)) # keeping the first dimension (samples)
encode[-1] = bottleneck # Nx512
cls = self.cls_head(encode[-1])

#cls = nn.functional.sigmoid(cls)
return cls

def freeze(self, layers: list) -> None:
''' Freeze all layers mentioned in the input list '''
for layer in layers:
self._freeze_layer(layer)

def _freeze_layer(self, layer_name: str) -> None:
for param in getattr(self, layer_name).parameters():
param.requires_grad = False

Loading