
Commit aa3ee0e

model-conversion : add embedding prompt file support (ggml-org#15871)
This commit adds support for passing a prompt file to the model-conversion targets/scripts. It also updates logits.cpp to print out embedding information in the same format as when running the original embedding model.

The motivation for this is that it allows us to pass files of different sizes when running the converted models and validating the logits. This is particularly important when testing the sliding-window functionality of models, where the sequence length needs to exceed a certain number of tokens to trigger the sliding-window logic.
1 parent d0991da commit aa3ee0e

File tree: 7 files changed, +187 -34 lines changed


examples/model-conversion/Makefile

Lines changed: 9 additions & 4 deletions
@@ -118,13 +118,17 @@ embedding-convert-model:
 
 embedding-run-original-model:
 	$(call validate_embedding_model_path,embedding-run-original-model)
-	@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/embedding/run-original-model.py
+	@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
+		./scripts/embedding/run-original-model.py \
+		$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
 
 embedding-run-converted-model:
-	@CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/embedding/run-converted-model.sh ${CONVERTED_EMBEDDING_MODEL}
+	@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
+		$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
 
 embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
-	@./scripts/embedding/compare-embeddings-logits.sh
+	@./scripts/embedding/compare-embeddings-logits.sh \
+		$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
 
 embedding-inspect-original-model:
 	$(call validate_embedding_model_path,embedding-inspect-original-model)
@@ -156,7 +160,8 @@ embedding-quantize-model:
 	$(call quantize_model,$(CONVERTED_EMBEDDING_MODEL),QUANTIZED_EMBEDDING_MODEL)
 
 embedding-run-quantized-model:
-	@./scripts/embedding/run-converted-model.sh ${QUANTIZED_EMBEDDING_MODEL}
+	@./scripts/embedding/run-converted-model.sh $(QUANTIZED_EMBEDDING_MODEL) \
+		$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
 
 ###
 ### Perplexity targets/recipes
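
With this, each embedding target forwards an optional PROMPTS_FILE variable to its underlying script. A minimal usage sketch; the model and prompt-file paths below are placeholders, not from this commit, and depending on the setup CONVERTED_EMBEDDING_MODEL may also need to be set:

    # From examples/model-conversion: run the original and converted models on
    # the same prompt file, then compare their embeddings (paths illustrative).
    make embedding-verify-logits \
        EMBEDDING_MODEL_PATH=/path/to/original-embedding-model \
        PROMPTS_FILE=prompts/long-context.txt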

examples/model-conversion/logits.cpp

Lines changed: 41 additions & 11 deletions
@@ -151,6 +151,35 @@ int main(int argc, char ** argv) {
         logits = llama_get_embeddings(ctx);
         n_logits = llama_model_n_embd(model) * batch.n_tokens;
         type = "-embeddings";
+
+        const int n_embd = llama_model_n_embd(model);
+        const int n_embd_count = batch.n_tokens;
+
+        printf("Embedding dimension: %d\n", n_embd);
+        printf("\n");
+
+        // Print embeddings in the specified format
+        for (int j = 0; j < n_embd_count; j++) {
+            printf("embedding %d: ", j);
+
+            // Print first 3 values
+            for (int i = 0; i < 3 && i < n_embd; i++) {
+                printf("%9.6f ", logits[j * n_embd + i]);
+            }
+
+            printf(" ... ");
+
+            // Print last 3 values
+            for (int i = n_embd - 3; i < n_embd; i++) {
+                if (i >= 0) {
+                    printf("%9.6f ", logits[j * n_embd + i]);
+                }
+            }
+
+            printf("\n");
+        }
+        printf("\n");
+
         printf("Embeddings size: %d\n", n_logits);
     } else {
         logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
@@ -183,22 +212,23 @@ int main(int argc, char ** argv) {
         return 1;
     }
     for (int i = 0; i < n_logits; i++) {
-        fprintf(f, "%d: %.6f\n", i, logits[i]); // Added index and changed format
+        fprintf(f, "%d: %.6f\n", i, logits[i]);
    }
    fclose(f);
 
-    // Print first and last 10 logits for quick verification
-    printf("First 10 logits: ");
-    for (int i = 0; i < 10 && i < n_logits; i++) {
-        printf("%.6f ", logits[i]);
-    }
-    printf("\n");
+    if (!embedding_mode) {
+        printf("First 10 logits: ");
+        for (int i = 0; i < 10 && i < n_logits; i++) {
+            printf("%.6f ", logits[i]);
+        }
+        printf("\n");
 
-    printf("Last 10 logits: ");
-    for (int i = n_logits - 10; i < n_logits; i++) {
-        if (i >= 0) printf("%.6f ", logits[i]);
+        printf("Last 10 logits: ");
+        for (int i = n_logits - 10; i < n_logits; i++) {
+            if (i >= 0) printf("%.6f ", logits[i]);
+        }
+        printf("\n\n");
    }
-    printf("\n\n");
 
    printf("Logits saved to %s\n", bin_filename);
    printf("Logits saved to %s\n", txt_filename);

examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh

Lines changed: 44 additions & 5 deletions
@@ -2,8 +2,37 @@
 
 set -e
 
-MODEL_PATH="${1:-"$EMBEDDING_MODEL_PATH"}"
-MODEL_NAME="${2:-$(basename "$MODEL_PATH")}"
+# Parse command line arguments
+MODEL_PATH=""
+MODEL_NAME=""
+PROMPTS_FILE=""
+
+# First argument is always model path
+if [ $# -gt 0 ] && [[ "$1" != --* ]]; then
+    MODEL_PATH="$1"
+    shift
+fi
+
+# Parse remaining arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --prompts-file|-pf)
+            PROMPTS_FILE="$2"
+            shift 2
+            ;;
+        *)
+            # If MODEL_NAME not set and this isn't a flag, use as model name
+            if [ -z "$MODEL_NAME" ] && [[ "$1" != --* ]]; then
+                MODEL_NAME="$1"
+            fi
+            shift
+            ;;
+    esac
+done
+
+# Set defaults
+MODEL_PATH="${MODEL_PATH:-"$EMBEDDING_MODEL_PATH"}"
+MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
 
 if [ -t 0 ]; then
     CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
@@ -35,8 +64,18 @@ with open('$TEMP_FILE', 'wb') as f:
     trap "rm -f $TEMP_FILE" EXIT
 fi
 
-python scripts/utils/semantic_check.py --model-path $MODEL_PATH \
+# Build the semantic_check.py command
+SEMANTIC_CMD="python scripts/utils/semantic_check.py --model-path $MODEL_PATH \
     --python-embeddings data/pytorch-${MODEL_NAME}-embeddings.bin \
-    --cpp-embeddings $CPP_EMBEDDINGS \
-    --prompt "Hello world today"
+    --cpp-embeddings $CPP_EMBEDDINGS"
+
+# Add prompts file if specified, otherwise use default prompt
+if [ -n "$PROMPTS_FILE" ]; then
+    SEMANTIC_CMD="$SEMANTIC_CMD --prompts-file \"$PROMPTS_FILE\""
+else
+    SEMANTIC_CMD="$SEMANTIC_CMD --prompt \"Hello world today\""
+fi
+
+# Execute the command
+eval $SEMANTIC_CMD
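
A note on the quoting here: the script assembles the command as a string and runs it through eval, which is why the prompts-file path has to be escape-quoted into the string. An equivalent sketch using a bash argument array, which preserves arguments containing spaces without eval (not part of this commit, just an alternative under the same variable names):

    # Sketch only: same behavior as the eval above, using an argument array.
    ARGS=(--model-path "$MODEL_PATH"
          --python-embeddings "data/pytorch-${MODEL_NAME}-embeddings.bin"
          --cpp-embeddings "$CPP_EMBEDDINGS")
    if [ -n "$PROMPTS_FILE" ]; then
        ARGS+=(--prompts-file "$PROMPTS_FILE")
    else
        ARGS+=(--prompt "Hello world today")
    fi
    python scripts/utils/semantic_check.py "${ARGS[@]}"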

examples/model-conversion/scripts/embedding/run-converted-model.sh

Lines changed: 34 additions & 4 deletions
@@ -2,8 +2,27 @@
 
 set -e
 
-# First try command line argument, then environment variable, then file
-CONVERTED_MODEL="${1:-"$CONVERTED_EMBEDDING_MODEL"}"
+# Parse command line arguments
+CONVERTED_MODEL=""
+PROMPTS_FILE=""
+
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -p|--prompts-file)
+            PROMPTS_FILE="$2"
+            shift 2
+            ;;
+        *)
+            if [ -z "$CONVERTED_MODEL" ]; then
+                CONVERTED_MODEL="$1"
+            fi
+            shift
+            ;;
+    esac
+done
+
+# First try command line argument, then environment variable
+CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}"
 
 # Final check if we have a model path
 if [ -z "$CONVERTED_MODEL" ]; then
@@ -13,8 +32,19 @@ if [ -z "$CONVERTED_MODEL" ]; then
     exit 1
 fi
 
+# Read prompt from file or use default
+if [ -n "$PROMPTS_FILE" ]; then
+    if [ ! -f "$PROMPTS_FILE" ]; then
+        echo "Error: Prompts file '$PROMPTS_FILE' not found" >&2
+        exit 1
+    fi
+    PROMPT=$(cat "$PROMPTS_FILE")
+else
+    PROMPT="Hello world today"
+fi
+
 echo $CONVERTED_MODEL
 
 cmake --build ../../build --target llama-logits -j8
-
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "Hello world today"
+# TODO: update logits.cpp to accept a --file/-f option for the prompt
+../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
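
Invoked directly rather than through the Makefile, the script might be used as below; the model and prompt paths are placeholders. Note that the whole file is read and passed to llama-logits as a single prompt string (hence the TODO about a proper --file option):

    ./scripts/embedding/run-converted-model.sh models/my-embedding-model.gguf \
        --prompts-file prompts/long-context.txt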

examples/model-conversion/scripts/embedding/run-original-model.py

Lines changed: 37 additions & 5 deletions
@@ -13,14 +13,37 @@
 
 parser = argparse.ArgumentParser(description='Process model with specified path')
 parser.add_argument('--model-path', '-m', help='Path to the model')
+parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
 args = parser.parse_args()
 
+def read_prompt_from_file(file_path):
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read().strip()
+    except FileNotFoundError:
+        print(f"Error: Prompts file '{file_path}' not found")
+        exit(1)
+    except Exception as e:
+        print(f"Error reading prompts file: {e}")
+        exit(1)
+
 model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path)
 if model_path is None:
     parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
 
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 
+config = AutoConfig.from_pretrained(model_path)
+
+# This can be used to override the sliding window size for manual testing. This
+# can be useful to verify the sliding window attention mask in the original model
+# and compare it with the converted .gguf model.
+if hasattr(config, 'sliding_window'):
+    original_sliding_window = config.sliding_window
+    #original_sliding_window = 6
+    print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
+
+print(f"Using unreleased model: {unreleased_model_name}")
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
     unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
@@ -29,19 +52,28 @@
 
     try:
         model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
-        model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained
+        model = model_class.from_pretrained(model_path, config=config)
     except (ImportError, AttributeError) as e:
         print(f"Failed to import or load model: {e}")
         exit(1)
 else:
-    model = AutoModel.from_pretrained(model_path)
+    model = AutoModel.from_pretrained(model_path, config=config)
 print(f"Model class: {type(model)}")
-#print(f"Model file: {type(model).__module__}")
-config = AutoConfig.from_pretrained(model_path)
+print(f"Model file: {type(model).__module__}")
+
+# Verify the model is using the correct sliding window
+if hasattr(model.config, 'sliding_window'):
+    print(f"Model's sliding_window: {model.config.sliding_window}")
+else:
+    print("Model config does not have sliding_window attribute")
 
 model_name = os.path.basename(model_path)
 
-texts = [ "Hello world today" ]
+if args.prompts_file:
+    prompt_text = read_prompt_from_file(args.prompts_file)
+    texts = [prompt_text]
+else:
+    texts = ["Hello world today"]
 
 encoded = tokenizer(
     texts,
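
A corresponding direct invocation of the original-model script (placeholder paths). Worth noting: despite the "one per line" wording in the --prompts-file help text, read_prompt_from_file returns the entire file contents as one string, so the file is treated as a single prompt:

    EMBEDDING_MODEL_PATH=/path/to/original-embedding-model \
        ./scripts/embedding/run-original-model.py \
        --prompts-file prompts/long-context.txt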

examples/model-conversion/scripts/utils/inspect-org-model.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@
     file_path = os.path.join(model_path, file_name)
     print(f"\n--- From {file_name} ---")
 
-    with safe_open(file_path, framework="pt") as f: # type: ignore
+    with safe_open(file_path, framework="pt") as f:
         for tensor_name in sorted(tensor_names):
             tensor = f.get_tensor(tensor_name)
             print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
@@ -49,7 +49,7 @@
     # Single file model (original behavior)
     print("Single-file model detected")
 
-    with safe_open(single_file_path, framework="pt") as f: # type: ignore
+    with safe_open(single_file_path, framework="pt") as f:
         keys = f.keys()
         print("Tensors in model:")
         for key in sorted(keys):

examples/model-conversion/scripts/utils/semantic_check.py

Lines changed: 20 additions & 3 deletions
@@ -101,21 +101,38 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
         'rms_diff': np.sqrt(np.mean(diff_matrix**2))
     }
 
+def read_prompt_from_file(file_path):
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read().strip()
+    except FileNotFoundError:
+        print(f"Error: Prompts file '{file_path}' not found")
+        exit(1)
+    except Exception as e:
+        print(f"Error reading prompts file: {e}")
+        exit(1)
+
 def main():
     parser = argparse.ArgumentParser(description='Test semantic similarity between Python and llama.cpp embeddings')
     parser.add_argument('--model-path', '-m', required=True, help='Path to the original Python model')
     parser.add_argument('--python-embeddings', '-pe', help='Path to pytorch embeddings "logits" binary file')
     parser.add_argument('--cpp-embeddings', '-ce', help='Path to llama.cpp embeddings "logits" binary file')
     parser.add_argument('--causal', '-c', default=False, help='if the model is causal (default: false)', action='store_true')
     parser.add_argument('--prompt', '-p', default='Hello world today', help='Test prompt')
+    parser.add_argument('--prompts-file', '-pf', help='Path to file containing prompts')
 
     args = parser.parse_args()
 
+    if args.prompts_file:
+        prompt = read_prompt_from_file(args.prompts_file)
+    else:
+        prompt = args.prompt
+
     print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
     print("=" * 70)
 
     # Single prompt detailed comparison
-    print(f"\nTesting with prompt: '{args.prompt}'")
+    print(f"\nTesting with prompt: '{prompt}'")
 
     # Load the python model to get configuration information and also to load the tokenizer.
     print("Loading model and tokenizer using AutoTokenizer:", args.model_path)
@@ -144,7 +161,7 @@ def main():
     else:
         model = AutoModel.from_pretrained(args.model_path)
 
-    encoded = tokenizer(args.prompt, return_tensors="pt")
+    encoded = tokenizer(prompt, return_tensors="pt")
     tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
     n_tokens = len(tokens)
     print(f"n_tokens: {n_tokens}");
@@ -155,7 +172,7 @@ def main():
     python_embeddings = load_embeddings_from_file(args.python_embeddings, n_tokens, model.config.hidden_size)
 
     # Run comparison
-    results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, args.prompt)
+    results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, prompt)
 
     # Summary
     print(f"\n=== SUMMARY ===")
