
Commit 9ecb758

[EM] Add basic distributed GPU tests. (#10861)
- Split Hist and Approx tests in unittests.
- Basic GPU tests for distributed.
1 parent 92f1c48 commit 9ecb758

File tree

4 files changed: +90 -136 lines

python-package/xgboost/testing/dask.py

Lines changed: 77 additions & 2 deletions
@@ -1,14 +1,16 @@
 """Tests for dask shared by different test modules."""
 
-from typing import Literal
+from typing import List, Literal, cast
 
 import numpy as np
 import pandas as pd
 from dask import array as da
 from dask import dataframe as dd
-from distributed import Client
+from distributed import Client, get_worker
 
 import xgboost as xgb
+import xgboost.testing as tm
+from xgboost.compat import concat
 from xgboost.testing.updater import get_basescore
 
 
@@ -91,3 +93,76 @@ def check_uneven_nan(
         dd.from_pandas(X, npartitions=n_workers),
         dd.from_pandas(y, npartitions=n_workers),
     )
+
+
+def check_external_memory(  # pylint: disable=too-many-locals
+    worker_id: int,
+    n_workers: int,
+    device: str,
+    comm_args: dict,
+    is_qdm: bool,
+) -> None:
+    """Basic checks for distributed external memory."""
+    n_samples_per_batch = 32
+    n_features = 4
+    n_batches = 16
+    use_cupy = device != "cpu"
+
+    n_threads = get_worker().state.nthreads
+    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
+        it = tm.IteratorForTest(
+            *tm.make_batches(
+                n_samples_per_batch,
+                n_features,
+                n_batches,
+                use_cupy=use_cupy,
+                random_state=worker_id,
+            ),
+            cache="cache",
+        )
+        if is_qdm:
+            Xy: xgb.DMatrix = xgb.ExtMemQuantileDMatrix(it, nthread=n_threads)
+        else:
+            Xy = xgb.DMatrix(it, nthread=n_threads)
+        results: xgb.callback.TrainingCallback.EvalsLog = {}
+        xgb.train(
+            {"tree_method": "hist", "nthread": n_threads, "device": device},
+            Xy,
+            evals=[(Xy, "Train")],
+            num_boost_round=32,
+            evals_result=results,
+        )
+        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))
+
+    lx, ly, lw = [], [], []
+    for i in range(n_workers):
+        x, y, w = tm.make_batches(
+            n_samples_per_batch,
+            n_features,
+            n_batches,
+            use_cupy=use_cupy,
+            random_state=i,
+        )
+        lx.extend(x)
+        ly.extend(y)
+        lw.extend(w)
+
+    X = concat(lx)
+    yconcat = concat(ly)
+    wconcat = concat(lw)
+    if is_qdm:
+        Xy = xgb.QuantileDMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
+    else:
+        Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
+
+    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
+    xgb.train(
+        {"tree_method": "hist", "nthread": n_threads, "device": device},
+        Xy,
+        evals=[(Xy, "Train")],
+        num_boost_round=32,
+        evals_result=results_local,
+    )
+    np.testing.assert_allclose(
+        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
+    )
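
A note on the GPU angle: check_external_memory is device-agnostic, which is what the "basic GPU tests" in the commit title builds on. Below is a minimal sketch, not part of this diff, of how a GPU counterpart to the CPU test later in this commit could drive the same helper; the test name, the device="cuda" value, and the placeholder for the communicator args are assumptions here.

# Hypothetical GPU variant of test_external_memory; assumes every dask worker
# can see a GPU and that cupy is installed. Not part of this commit.
import pytest
from distributed import Client, Scheduler, Worker
from distributed.utils_test import gen_cluster

from xgboost import testing as tm
from xgboost.testing.dask import check_external_memory


@pytest.mark.parametrize("is_qdm", [True, False])
@gen_cluster(client=True)
async def test_external_memory_gpu(
    client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool
) -> None:
    workers = tm.get_client_workers(client)
    # Placeholder: obtain the rabit communicator args the same way the CPU
    # test does (the exact call sits outside the hunks shown in this diff).
    comm_args: dict = {}
    futs = client.map(
        check_external_memory,
        range(len(workers)),
        n_workers=len(workers),
        device="cuda",  # any value other than "cpu" makes the helper use cupy
        comm_args=comm_args,
        is_qdm=is_qdm,
    )
    await client.gather(futs)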

tests/cpp/tree/test_gpu_hist.cu

Lines changed: 0 additions & 51 deletions
@@ -318,55 +318,4 @@ TEST_F(MGPUHistTest, HistColumnSplit) {
   this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, true);
   this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, false);
 }
-
-namespace {
-RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
-  ObjInfo task{ObjInfo::kRegression};
-  std::unique_ptr<TreeUpdater> approx_maker{TreeUpdater::Create("grow_gpu_approx", ctx, &task)};
-  approx_maker->Configure(Args{});
-
-  TrainParam param;
-  param.UpdateAllowUnknown(Args{});
-
-  linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Device());
-  gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));
-
-  std::vector<HostDeviceVector<bst_node_t>> position(1);
-  RegTree tree;
-  approx_maker->Update(&param, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
-                       {&tree});
-  return tree;
-}
-
-void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) {
-  auto ctx = MakeCUDACtx(DistGpuIdx());
-
-  auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
-  auto const world_size = collective::GetWorldSize();
-  auto const rank = collective::GetRank();
-  std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
-
-  RegTree tree = GetApproxTree(&ctx, sliced.get());
-
-  Json json{Object{}};
-  tree.SaveModel(&json);
-  Json expected_json{Object{}};
-  expected_tree.SaveModel(&expected_json);
-  ASSERT_EQ(json, expected_json);
-}
-}  // anonymous namespace
-
-class MGPUApproxTest : public collective::BaseMGPUTest {};
-
-TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
-  auto constexpr kRows = 32;
-  auto constexpr kCols = 16;
-
-  Context ctx(MakeCUDACtx(0));
-  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
-  RegTree expected_tree = GetApproxTree(&ctx, dmat.get());
-
-  this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, true);
-  this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, false);
-}
 }  // namespace xgboost::tree
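
This removal is the "Split Hist and Approx tests" half of the commit: the GPU approx column-split coverage (GetApproxTree, VerifyApproxColumnSplit, MGPUApproxTest) moves out of test_gpu_hist.cu, presumably into an approx-specific unit-test file; the destination file is not part of this diff view.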
tests/test_distributed/test_with_dask/test_external_memory.py

Lines changed: 12 additions & 66 deletions
@@ -1,77 +1,18 @@
-from typing import List, cast
+"""Copyright 2024, XGBoost contributors"""
 
-import numpy as np
-from distributed import Client, Scheduler, Worker, get_worker
+import pytest
+from distributed import Client, Scheduler, Worker
 from distributed.utils_test import gen_cluster
 
 import xgboost as xgb
 from xgboost import testing as tm
-from xgboost.compat import concat
-
-
-def run_external_memory(worker_id: int, n_workers: int, comm_args: dict) -> None:
-    n_samples_per_batch = 32
-    n_features = 4
-    n_batches = 16
-    use_cupy = False
-
-    n_threads = get_worker().state.nthreads
-    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
-        it = tm.IteratorForTest(
-            *tm.make_batches(
-                n_samples_per_batch,
-                n_features,
-                n_batches,
-                use_cupy,
-                random_state=worker_id,
-            ),
-            cache="cache",
-        )
-        Xy = xgb.DMatrix(it, nthread=n_threads)
-        results: xgb.callback.TrainingCallback.EvalsLog = {}
-        booster = xgb.train(
-            {"tree_method": "hist", "nthread": n_threads},
-            Xy,
-            evals=[(Xy, "Train")],
-            num_boost_round=32,
-            evals_result=results,
-        )
-        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))
-
-    lx, ly, lw = [], [], []
-    for i in range(n_workers):
-        x, y, w = tm.make_batches(
-            n_samples_per_batch,
-            n_features,
-            n_batches,
-            use_cupy,
-            random_state=i,
-        )
-        lx.extend(x)
-        ly.extend(y)
-        lw.extend(w)
-
-    X = concat(lx)
-    yconcat = concat(ly)
-    wconcat = concat(lw)
-    Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)
-
-    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
-    booster = xgb.train(
-        {"tree_method": "hist", "nthread": n_threads},
-        Xy,
-        evals=[(Xy, "Train")],
-        num_boost_round=32,
-        evals_result=results_local,
-    )
-    np.testing.assert_allclose(
-        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
-    )
+from xgboost.testing.dask import check_external_memory
 
 
+@pytest.mark.parametrize("is_qdm", [True, False])
 @gen_cluster(client=True)
 async def test_external_memory(
-    client: Client, s: Scheduler, a: Worker, b: Worker
+    client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool
 ) -> None:
     workers = tm.get_client_workers(client)
     args = await client.sync(
@@ -83,6 +24,11 @@ async def test_external_memory(
     n_workers = len(workers)
 
     futs = client.map(
-        run_external_memory, range(n_workers), n_workers=n_workers, comm_args=args
+        check_external_memory,
+        range(n_workers),
+        n_workers=n_workers,
+        device="cpu",
+        comm_args=args,
+        is_qdm=is_qdm,
     )
     await client.gather(futs)
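
Design note: parametrizing on is_qdm lets one test body cover both the plain DMatrix path and the external-memory quantile path (ExtMemQuantileDMatrix on the worker side), while making device an explicit argument ("cpu" here) leaves check_external_memory reusable by a GPU test without modification, as in the sketch after the dask.py diff above.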

tests/test_distributed/test_with_dask/test_with_dask.py

Lines changed: 1 addition & 17 deletions
@@ -7,24 +7,9 @@
 import socket
 import tempfile
 from concurrent.futures import ThreadPoolExecutor
-from copy import copy
 from functools import partial
-from itertools import starmap
-from math import ceil
-from operator import attrgetter, getitem
 from pathlib import Path
-from typing import (
-    Any,
-    Dict,
-    Generator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, Union
 
 import hypothesis
 import numpy as np
@@ -37,7 +22,6 @@
 import xgboost as xgb
 from xgboost import dask as dxgb
 from xgboost import testing as tm
-from xgboost.data import _is_cudf_df
 from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
 from xgboost.testing.shared import (
     get_feature_weights,
