Skip to content

Commit b6341d9

Browse files
authored
Merge pull request #1 from goterria/main
Corrections for batch size
2 parents 963f329 + 6ccbed3 commit b6341d9

File tree

7 files changed

+396
-38
lines changed

7 files changed

+396
-38
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ dist/
66
build/
77
images/
88
outputs/
9+
output/
10+
DATASETS/
911
multirun/
1012
exp/
1113
handling/
1214
tests/
13-
scripts/
15+
wandb/
1416
*.code-workspace
1517

1618
configs/train/*

animaloc/eval/evaluators.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,10 @@ def evaluate(self, returns: str = 'recall', wandb_flag: bool = False, viz: bool
200200
if i % self.print_freq == 0 or i == len(self.dataloader) - 1:
201201
fig = self._vizual(image = images, target = targets, output = output)
202202
wandb.log({'validation_vizuals': fig})
203-
204-
output = self.prepare_feeding(targets, output)
205-
206-
iter_metrics.feed(**output)
207-
iter_metrics.aggregate()
203+
for b in range(images.shape[0]):
204+
batch_output = self.prepare_feeding(dict(labels= targets['labels'][b], points= targets['points'][b]), (output[0][b].unsqueeze(0), output[1][b].unsqueeze(0)))
205+
iter_metrics.feed(**batch_output)
206+
iter_metrics.aggregate()
208207
if log_meters:
209208
logger.add_meter('n', sum(iter_metrics.tp) + sum(iter_metrics.fn))
210209
logger.add_meter('recall', round(iter_metrics.recall(),2))
@@ -226,8 +225,10 @@ def evaluate(self, returns: str = 'recall', wandb_flag: bool = False, viz: bool
226225
})
227226

228227
iter_metrics.flush()
229-
230-
self.metrics.feed(**output)
228+
for b in range(images.shape[0]):
229+
batch_output = self.prepare_feeding(dict(labels= targets['labels'][b], points= targets['points'][b]), (output[0][b].unsqueeze(0), output[1][b].unsqueeze(0)))
230+
self.metrics.feed(**batch_output)
231+
#self.metrics.feed(**output)
231232

232233
self._stored_metrics = self.metrics.copy()
233234

@@ -347,14 +348,16 @@ def post_stitcher(self, output: torch.Tensor) -> Any:
347348

348349
def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: List[torch.Tensor]) -> dict:
349350

350-
gt_coords = [p[::-1] for p in targets['points'].squeeze(0).tolist()]
351-
gt_labels = targets['labels'].squeeze(0).tolist()
352-
351+
gt_coords = [p[::-1] for p in targets['points'].tolist()]
352+
gt_labels = targets['labels'].tolist()
353+
354+
ndim= numpy.array(gt_coords).ndim
353355
gt = dict(
354356
loc = gt_coords,
355357
labels = gt_labels
356358
)
357359

360+
358361
up = True
359362
if self.stitcher is not None:
360363
up = False
@@ -365,8 +368,8 @@ def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: List[torch.T
365368
preds = dict(
366369
loc = locs[0],
367370
labels = labels[0],
368-
scores = scores[0],
369-
dscores = dscores[0]
371+
scores = scores[0], # class scores
372+
dscores = dscores[0] # heatmap scores
370373
)
371374

372375
return dict(gt = gt, preds = preds, est_count = counts[0])

docker/Dockerfile

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,38 @@
1-
FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime AS base
1+
2+
ARG PYTORCH="1.11.0"
3+
ARG CUDA="11.3"
4+
ARG CUDNN="8"
5+
6+
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel AS base
7+
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX 9.0"
8+
ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9+
210
RUN conda update conda && conda install pip && conda clean -afy
311
WORKDIR /herdnet
412
COPY environment-dev.yml ./
513
RUN conda env update -f environment-dev.yml -n base && conda clean -afy
614

7-
8-
#FROM base as dep_builder
9-
## Some deps must be built (e.g. against the conda GDAL)
10-
#RUN apt-get update \
11-
# && apt-get install -y gcc build-essential \
12-
# && rm -rf /var/lib/apt/lists/*
13-
#COPY pyproject.toml setup.cfg ./
14-
#COPY src/stactools/core/__init__.py src/stactools/core/
15-
## Install dependencies but remove the actual package
16-
#RUN pip install --prefix=/install .[all] \
17-
# && rm -r /install/lib/*/site-packages/stactools*
18-
19-
2015
FROM base AS dev
2116
# Install make for the docs build
22-
RUN apt-get update \
23-
&& apt-get install -y gcc make build-essential git \
17+
# solves a weired problem with NVIDIA with https://github.com/NVIDIA/nvidia-container-toolkit/issues/258#issuecomment-1903945418
18+
RUN \
19+
# Update nvidia GPG key
20+
rm /etc/apt/sources.list.d/cuda.list && \
21+
rm /etc/apt/sources.list.d/nvidia-ml.list && \
22+
apt-key del 7fa2af80 && \
23+
apt-get update && apt-get install -y --no-install-recommends wget && \
24+
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb && \
25+
dpkg -i cuda-keyring_1.0-1_all.deb && \
26+
apt-get update
27+
28+
RUN apt-get update \
29+
&& apt-get install -y --no-install-recommends python3-pyqt5 python3-pyqt5.qtwebengine unzip git \
2430
&& rm -rf /var/lib/apt/lists/*
25-
#COPY --from=dep_builder /install /opt/conda
26-
#RUN conda install -c conda-forge pandoc && conda clean -af
31+
2732
COPY requirements-dev.txt ./
2833
RUN pip install -r requirements-dev.txt
2934
COPY . ./
3035
# pre-commit run --all-files fails w/o this line
3136
RUN git init
3237
RUN pip install -e .
3338

34-
35-
#FROM base AS main
36-
#COPY --from=dep_builder /install /opt/conda
37-
#COPY src ./src
38-
#COPY pyproject.toml setup.cfg ./
39-
#RUN pip install .[all]

docker/build

100644100755
File mode changed.

docker/run

100644100755
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Run a console in a docker container with all prerequisites installed.
1717

1818
if [ "${BASH_SOURCE[0]}" = "${0}" ]; then
1919
docker run --rm -it --gpus all\
20+
-v /home/fous3401/DATASETS:/herdnet/DATASETS \
21+
--ipc=host \
2022
-v `pwd`:/herdnet \
2123
-p 8000:8000 \
2224
--entrypoint /bin/bash \

requirements-dev.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ pandas
1111
#opencv-python
1212
#opencv-python-headless
1313
hydra-core
14-
wandb
14+
wandb
15+
gdown

0 commit comments

Comments
 (0)