@@ -7,8 +7,9 @@
 import numpy as np
 import torch
 import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
-from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
+from accelerate import infer_auto_device_map, init_empty_weights
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)
 
 import modules.shared as shared
 
@@ -113,23 +114,20 @@ def load_model(model_name):
 
     if shared.args.gpu_memory:
         memory_map = shared.args.gpu_memory
-        max_memory = { 0: f'{memory_map[0]}GiB' }
-        for i in range(1, len(memory_map)):
+        max_memory = {}
+        for i in range(len(memory_map)):
             max_memory[i] = f'{memory_map[i]}GiB'
         max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
         params['max_memory'] = max_memory
     else:
-        total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
-        suggestion = round((total_mem - 1000) / 1000) * 1000
+        total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
+        suggestion = round((total_mem-1000) / 1000) * 1000
         if total_mem - suggestion < 800:
             suggestion -= 1000
         suggestion = int(round(suggestion/1000))
         print(f"\033[1;32;1mAuto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
 
-        max_memory = {
-            0: f'{suggestion}GiB',
-            'cpu': f'{shared.args.cpu_memory or 99}GiB'
-        }
+        max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
         params['max_memory'] = max_memory
 
     if shared.args.disk:
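
The else branch suggests a per-GPU budget by rounding the card's VRAM down to whole GiB while reserving roughly 1 GiB of headroom. A standalone rerun of that arithmetic (the function name is mine, and the VRAM sizes are hard-coded illustrative values so it runs without a GPU):

def suggest_gpu_memory(total_mem_mib):
    # Round to the nearest 1000 MiB after subtracting 1 GiB of headroom...
    suggestion = round((total_mem_mib - 1000) / 1000) * 1000
    # ...and back off another 1000 MiB if less than 800 MiB of slack remains.
    if total_mem_mib - suggestion < 800:
        suggestion -= 1000
    return int(round(suggestion / 1000))

print(suggest_gpu_memory(12288))  # 12 GiB card -> suggests --gpu-memory 11
print(suggest_gpu_memory(24576))  # 24 GiB card -> suggests --gpu-memory 23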
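
Either way, a max_memory map of this shape ends up in params and is forwarded to transformers/accelerate, which cap how many weights land on each device (integer keys are GPU indices, 'cpu' caps RAM used for offloading). A minimal sketch of that consumption, not part of this commit; the checkpoint name and budgets are placeholders:

from transformers import AutoModelForCausalLM

# e.g. what the loop above builds from `--gpu-memory 11 11 --cpu-memory 99`
max_memory = {0: '11GiB', 1: '11GiB', 'cpu': '99GiB'}

model = AutoModelForCausalLM.from_pretrained(
    'facebook/opt-1.3b',   # placeholder checkpoint
    device_map='auto',     # let accelerate split layers under the budgets
    max_memory=max_memory,
)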