Skip to content

Commit b61aef4

Browse files
committed
fix(gpu-detection): default to CPU if there is less than 4GB of GPU available
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 2b44467 commit b61aef4

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

core/gallery/backends_test.go

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ var _ = Describe("Gallery Backends", func() {
105105
Name: "meta-backend",
106106
},
107107
CapabilitiesMap: map[string]string{
108-
"nvidia": "nvidia-backend",
109-
"amd": "amd-backend",
110-
"intel": "intel-backend",
111-
"metal": "metal-backend",
108+
"nvidia": "nvidia-backend",
109+
"amd": "amd-backend",
110+
"intel": "intel-backend",
111+
"metal": "metal-backend",
112+
"default": "default-backend",
112113
},
113114
}
114115

@@ -133,7 +134,14 @@ var _ = Describe("Gallery Backends", func() {
133134
URI: testImage,
134135
}
135136

136-
backends := GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend}
137+
defaultBackend := &GalleryBackend{
138+
Metadata: Metadata{
139+
Name: "default-backend",
140+
},
141+
URI: testImage,
142+
}
143+
144+
backends := GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend, defaultBackend}
137145

138146
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
139147
metal := &system.SystemState{}
@@ -142,15 +150,26 @@ var _ = Describe("Gallery Backends", func() {
142150

143151
} else {
144152
// Test with NVIDIA system state
145-
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia"}
153+
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
146154
bestBackend := metaBackend.FindBestBackendFromMeta(nvidiaSystemState, backends)
147155
Expect(bestBackend).To(Equal(nvidiaBackend))
148156

149157
// Test with AMD system state
150-
amdSystemState := &system.SystemState{GPUVendor: "amd"}
158+
amdSystemState := &system.SystemState{GPUVendor: "amd", VRAM: 1000000000000}
151159
bestBackend = metaBackend.FindBestBackendFromMeta(amdSystemState, backends)
152160
Expect(bestBackend).To(Equal(amdBackend))
153161

162+
// Test with an AMD GPU whose VRAM is left unset (zero): not enough VRAM, so the default backend is expected
163+
defaultSystemState := &system.SystemState{GPUVendor: "amd"}
164+
bestBackend = metaBackend.FindBestBackendFromMeta(defaultSystemState, backends)
165+
Expect(bestBackend).To(Equal(defaultBackend))
166+
167+
// Test with default system state
168+
defaultSystemState = &system.SystemState{GPUVendor: "default"}
169+
bestBackend = metaBackend.FindBestBackendFromMeta(defaultSystemState, backends)
170+
Expect(bestBackend).To(Equal(defaultBackend))
171+
172+
backends = GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend}
154173
// Test with unsupported GPU vendor
155174
unsupportedSystemState := &system.SystemState{GPUVendor: "unsupported"}
156175
bestBackend = metaBackend.FindBestBackendFromMeta(unsupportedSystemState, backends)
@@ -201,7 +220,7 @@ var _ = Describe("Gallery Backends", func() {
201220
Expect(err).NotTo(HaveOccurred())
202221

203222
// Test with NVIDIA system state
204-
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia"}
223+
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
205224
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
206225
Expect(err).NotTo(HaveOccurred())
207226

@@ -272,7 +291,7 @@ var _ = Describe("Gallery Backends", func() {
272291
Expect(err).NotTo(HaveOccurred())
273292

274293
// Test with NVIDIA system state
275-
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia"}
294+
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
276295
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
277296
Expect(err).NotTo(HaveOccurred())
278297

@@ -344,7 +363,7 @@ var _ = Describe("Gallery Backends", func() {
344363
Expect(err).NotTo(HaveOccurred())
345364

346365
// Test with NVIDIA system state
347-
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia"}
366+
nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000}
348367
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, "meta-backend", tempDir, nil, true)
349368
Expect(err).NotTo(HaveOccurred())
350369

pkg/system/capabilities.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ import (
55
"runtime"
66
"strings"
77

8+
"github.com/jaypipes/ghw/pkg/gpu"
89
"github.com/mudler/LocalAI/pkg/xsysinfo"
910
"github.com/rs/zerolog/log"
1011
)
1112

1213
type SystemState struct {
1314
GPUVendor string
15+
gpus []*gpu.GraphicsCard
16+
VRAM uint64
1417
}
1518

1619
const (
@@ -91,24 +94,32 @@ func (s *SystemState) getSystemCapabilities() string {
9194
}
9295

9396
log.Info().Str("Capability", s.GPUVendor).Msgf("Capability automatically detected, set %s to override", capabilityEnv)
97+
// If VRAM is 4GB or less, default to CPU and warn the user that this can be overridden via env
98+
if s.VRAM <= 4*1024*1024*1024 {
99+
log.Warn().Msgf("VRAM is less than 4GB, defaulting to CPU. Set %s to override", capabilityEnv)
100+
return defaultCapability
101+
}
102+
94103
return s.GPUVendor
95104
}
96105

97106
func GetSystemState() (*SystemState, error) {
98-
gpuVendor, _ := detectGPUVendor()
107+
// GPU detection is best-effort: errors are ignored so that a detection failure does not fail this call
108+
gpus, _ := xsysinfo.GPUs()
109+
log.Debug().Any("gpus", gpus).Msg("GPUs")
110+
gpuVendor, _ := detectGPUVendor(gpus)
99111
log.Debug().Str("gpuVendor", gpuVendor).Msg("GPU vendor")
112+
vram, _ := xsysinfo.TotalAvailableVRAM()
113+
log.Debug().Any("vram", vram).Msg("Total available VRAM")
100114

101115
return &SystemState{
102116
GPUVendor: gpuVendor,
117+
gpus: gpus,
118+
VRAM: vram,
103119
}, nil
104120
}
105121

106-
func detectGPUVendor() (string, error) {
107-
gpus, err := xsysinfo.GPUs()
108-
if err != nil {
109-
return "", err
110-
}
111-
122+
func detectGPUVendor(gpus []*gpu.GraphicsCard) (string, error) {
112123
for _, gpu := range gpus {
113124
if gpu.DeviceInfo != nil {
114125
if gpu.DeviceInfo.Vendor != nil {

0 commit comments

Comments
 (0)