Skip to content

Commit ac0c565

Browse files
committed
Config options for model default parameters
Signed-off-by: Ed Snible <[email protected]>
1 parent 8a1e543 commit ac0c565

File tree

6 files changed

+222
-63
lines changed

6 files changed

+222
-63
lines changed

src/pdl/pdl.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
RoleType,
1818
ScopeType,
1919
empty_block_location,
20+
get_default_model_parameters,
2021
)
2122
from .pdl_interpreter import InterpreterState, process_prog
2223
from .pdl_parser import parse_file, parse_str
2324
from .pdl_runner import exec_docker
25+
from .pdl_utils import validate_scope
2426

2527
logger = logging.getLogger(__name__)
2628

@@ -163,7 +165,7 @@ def main():
163165
"-f",
164166
"--data-file",
165167
dest="data_file",
166-
help="file containing initial values to add to the scope",
168+
help="YAML file containing initial values to add to the scope",
167169
)
168170
parser.add_argument(
169171
"-d",
@@ -233,12 +235,15 @@ def main():
233235
exec_docker(*args)
234236
assert False # unreachable: exec_docker terminate the execution
235237

236-
initial_scope = {}
238+
initial_scope = {
239+
"pdl_model_default_parameters": get_default_model_parameters()
240+
}
237241
if args.data_file is not None:
238242
with open(args.data_file, "r", encoding="utf-8") as scope_fp:
239243
initial_scope = yaml.safe_load(scope_fp)
240244
if args.data is not None:
241245
initial_scope = initial_scope | yaml.safe_load(args.data)
246+
validate_scope(initial_scope)
242247

243248
match args.stream:
244249
case "result":

src/pdl/pdl_ast.py

Lines changed: 28 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def __init__(self, message):
542542

543543
MAX_NEW_TOKENS = 1024
544544
MIN_NEW_TOKENS = 1
545-
REPETITION_PENATLY = 1.05
545+
REPETITION_PENALTY = 1.05
546546
TEMPERATURE_SAMPLING = 0.7
547547
TOP_P_SAMPLING = 0.85
548548
TOP_K_SAMPLING = 50
@@ -568,49 +568,19 @@ def set_structured_decoding_parameters(
568568
return parameters
569569

570570

571-
def set_default_granite_model_parameters(
572-
model_id: str,
573-
spec: Any,
574-
parameters: Optional[dict[str, Any]],
575-
) -> dict[str, Any]:
576-
if parameters is None:
577-
parameters = {}
578-
579-
if "watsonx" in model_id:
580-
if "decoding_method" not in parameters:
581-
parameters["decoding_method"] = (
582-
DECODING_METHOD # pylint: disable=attribute-defined-outside-init
583-
)
584-
if "max_tokens" in parameters and parameters["max_tokens"] is None:
585-
parameters["max_tokens"] = (
586-
MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
587-
)
588-
if "min_new_tokens" not in parameters:
589-
parameters["min_new_tokens"] = (
590-
MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
591-
)
592-
if "repetition_penalty" not in parameters:
593-
parameters["repetition_penalty"] = (
594-
REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
595-
)
596-
if parameters["decoding_method"] == "sample":
597-
if "temperature" not in parameters:
598-
parameters["temperature"] = (
599-
TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
600-
)
601-
if "top_k" not in parameters:
602-
parameters["top_k"] = (
603-
TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
604-
)
605-
if "top_p" not in parameters:
606-
parameters["top_p"] = (
607-
TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
608-
)
609-
if "replicate" in model_id and "granite-3.0" in model_id:
610-
if "temperature" not in parameters or parameters["temperature"] is None:
611-
parameters["temperature"] = 0 # setting to decoding greedy
612-
if "roles" not in parameters:
613-
parameters["roles"] = {
571+
def get_default_model_parameters() -> list[dict[str, Any]]:
572+
"""Model-specific defaults to apply"""
573+
return [
574+
{ "*watsonx*": {
575+
"decoding_method": DECODING_METHOD,
576+
"max_tokens": MAX_NEW_TOKENS,
577+
"min_new_tokens": MIN_NEW_TOKENS,
578+
"repetition_penalty": REPETITION_PENALTY,
579+
},
580+
},
581+
{ "replicate*granite-3.0*": {
582+
"temperature": 0,
583+
"roles": {
614584
"system": {
615585
"pre_message": "<|start_of_role|>system<|end_of_role|>",
616586
"post_message": "<|end_of_text|>",
@@ -631,10 +601,17 @@ def set_default_granite_model_parameters(
631601
"pre_message": "<|start_of_role|>tool_response<|end_of_role|>",
632602
"post_message": "<|end_of_text|>",
633603
},
634-
}
635-
if "final_prompt_value" not in parameters:
636-
parameters["final_prompt_value"] = (
637-
"<|start_of_role|>assistant<|end_of_role|>"
638-
)
639-
640-
return parameters
604+
},
605+
"final_prompt_value": "<|start_of_role|>assistant<|end_of_role|>"
606+
}
607+
}]
608+
609+
def get_sampling_defaults() -> list[dict[str, Any]]:
    """Return the defaults to apply when a model call uses sampling.

    The result has the same layout as the other model-default lists: a list
    of single-key dicts whose key is a glob matched against the model id and
    whose value is the dict of default parameters. The only entry uses the
    catch-all glob ``"*"``.
    """
    sampling_parameters: dict[str, Any] = {
        "temperature": TEMPERATURE_SAMPLING,
        "top_k": TOP_K_SAMPLING,
        "top_p": TOP_P_SAMPLING,
    }
    return [{"*": sampling_parameters}]

src/pdl/pdl_interpreter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
messages_concat,
9090
replace_contribute_value,
9191
stringify,
92+
apply_defaults,
9293
)
9394

9495
logger = logging.getLogger(__name__)
@@ -1114,6 +1115,11 @@ def step_call_model(
11141115
litellm_params = {}
11151116

11161117
def get_transformed_inputs(kwargs):
1118+
# Apply PDL defaults to model invocation
1119+
kwargs['optional_params'] = apply_defaults(kwargs['model'],
1120+
kwargs['optional_params'],
1121+
scope['pdl_model_default_parameters'])
1122+
11171123
params_to_model = kwargs["additional_args"]["complete_input_dict"]
11181124
nonlocal litellm_params
11191125
litellm_params = params_to_model

src/pdl/pdl_llms.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from .pdl_ast import (
99
Message,
10-
set_default_granite_model_parameters,
1110
set_structured_decoding_parameters,
1211
)
1312
from .pdl_utils import remove_none_values_from_message
@@ -38,10 +37,6 @@ def generate_text(
3837
spec: Any,
3938
parameters: dict[str, Any],
4039
) -> tuple[Message, Any]:
41-
if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
42-
parameters = set_default_granite_model_parameters(
43-
model_id, spec, parameters
44-
)
4540
parameters = set_structured_decoding_parameters(spec, parameters)
4641
if parameters.get("mock_response") is not None:
4742
litellm.suppress_debug_info = True
@@ -63,10 +58,6 @@ def generate_text_stream(
6358
spec: Any,
6459
parameters: dict[str, Any],
6560
) -> Generator[Message, Any, Any]:
66-
if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
67-
parameters = set_default_granite_model_parameters(
68-
model_id, spec, parameters
69-
)
7061
parameters = set_structured_decoding_parameters(spec, parameters)
7162
response = completion(
7263
model=model_id,

src/pdl/pdl_utils.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import fnmatch
12
import json
23
from typing import Any, Sequence
34

4-
from .pdl_ast import ContributeTarget, ContributeValue, FunctionBlock, Message, Messages
5+
from .pdl_ast import ContributeTarget, ContributeValue, FunctionBlock, Message, Messages, get_sampling_defaults
56

67

78
def stringify(result):
@@ -83,3 +84,67 @@ def remove_none_values_from_message(message: Any) -> dict[str, Any]:
8384
else:
8485
ret[key] = value
8586
return ret
87+
88+
def apply_defaults(
    model_id: str,
    params: dict[str, Any],
    all_model_defaults: list[dict[str, dict[str, Any]]],
) -> dict[str, Any]:
    """Combine the explicit parameters of a model call with configured defaults.

    Args:
        model_id: Identifier of the model about to be called.
        params: Parameters given explicitly for this call.
        all_model_defaults: List of ``{model glob: defaults}`` dicts.

    Returns:
        ``params`` itself when the model id contains
        ``granite-20b-code-instruct-r1.1``; otherwise the result of
        ``apply_raw_defaults``, additionally completed with
        ``get_sampling_defaults()`` when the decoding method is ``"sample"``.
    """
    # Never apply defaults to granite-20b-code-instruct-r1.1
    if "granite-20b-code-instruct-r1.1" in model_id:
        return params

    with_defaults = apply_raw_defaults(model_id, params, all_model_defaults)
    if with_defaults.get("decoding_method") != "sample":
        return with_defaults

    # Sampling was requested (explicitly or through a default): fill in the
    # sampling-specific parameters that are still unset.
    return apply_raw_defaults(model_id, with_defaults, get_sampling_defaults())
99+
100+
def apply_raw_defaults(
    model_id: str,
    params: dict[str, Any],
    model_defaults: list[dict[str, dict[str, Any]]],
) -> dict[str, Any]:
    """Apply defaults to params based on a list of model defaults.

    Args:
        model_id: A PDL model ID
        params: The explicit parameters set in PDL
        model_defaults: A list of dicts, where the keys are globs for model id,
            and the value is a dict of defaults

    Returns:
        The parameters to send to the LLM, as a new dict; ``params`` itself is
        not modified.

    Raises:
        ValueError: If ``model_defaults`` is not a list of dicts mapping a
            model glob to a dict of defaults.
    """
    # Function-scope import so this block does not depend on the module's
    # import list.
    import logging

    # Internal invariants of the callers.
    assert isinstance(model_id, str)
    assert isinstance(params, dict)

    # The defaults may come from user-provided data, so they are checked with
    # real exceptions rather than asserts (asserts are stripped under -O).
    if not isinstance(model_defaults, list):
        raise ValueError(
            f"invalid model defaults type {type(model_defaults)}, expected a list"
        )

    logger = logging.getLogger(__name__)

    # Construct defaults for this model. If more than one set of default
    # applies, the last seen default "wins".
    default_union: dict[str, Any] = {}
    for model_default in model_defaults:
        if not isinstance(model_default, dict):
            raise ValueError(
                f"invalid model default entry type {type(model_default)}, expected a dict"
            )
        for model_glob, glob_defaults in model_default.items():
            if not isinstance(glob_defaults, dict):
                raise ValueError(
                    f"invalid default type {type(glob_defaults)} for model matcher {model_glob}"
                )
            if fnmatch.fnmatchcase(model_id, model_glob):
                # Diagnostic only: logged instead of printed so that it does
                # not end up on standard output.
                logger.debug("model %s matches %s", model_id, model_glob)
                default_union.update(glob_defaults)

    # Apply final list of defaults to explicit parameters. An explicit value
    # of None counts as "unset" and is replaced by the default.
    retval = dict(params)
    for key, value in default_union.items():
        if retval.get(key) is None:
            retval[key] = value
    return retval
136+
137+
def validate_scope(scope: dict):
    """Throw an exception if any key in scope is invalid.

    Only the ``pdl_model_default_parameters`` entry is checked, and only when
    it is present: a scope that does not define the key (for example one
    loaded from user-provided data) has nothing invalid in it, so it is
    accepted instead of failing with a ``KeyError``.

    Args:
        scope: The initial scope to check.

    Raises:
        ValueError: If ``pdl_model_default_parameters`` is present but not in
            the expected format.
    """
    if "pdl_model_default_parameters" not in scope:
        return
    validate_pdl_model_defaults(scope["pdl_model_default_parameters"])
140+
141+
def validate_pdl_model_defaults(model_defaults: list[dict[str, dict[str, Any]]]):
    """Throw an exception if the model_defaults is not in expected format.

    The expected format is a list of dicts; each dict maps a model-id glob
    (a string) to a dict of default parameters.

    Args:
        model_defaults: The value to check.

    Raises:
        ValueError: If the container is not a list, an entry is not a dict, a
            glob is not a string, or the defaults of a glob are not a dict.
    """
    # The value may come from user-provided data, so every check raises a real
    # exception; asserts would be stripped when Python runs with -O.
    if not isinstance(model_defaults, list):
        raise ValueError(f"invalid model defaults {model_defaults}: expected a list")

    for model_default in model_defaults:
        if not isinstance(model_default, dict):
            raise ValueError(
                f"invalid model default entry {model_default}: expected a dict"
            )
        for model_glob, glob_defaults in model_default.items():
            # Globs are later passed to fnmatch, which only accepts strings.
            if not isinstance(model_glob, str):
                raise ValueError(
                    f"invalid model matcher {model_glob}: expected a string"
                )
            if not isinstance(glob_defaults, dict):
                raise ValueError(
                    f"invalid defaults {glob_defaults} for model matcher {model_glob}"
                )

tests/test_defaults.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from pdl.pdl_utils import apply_defaults
2+
from pdl.pdl_ast import get_default_model_parameters
3+
4+
def test_default_model_params_empty():
    """With no configured defaults, empty parameters stay empty."""
    model_id = "replicate/ibm-granite/granite-20b-code-instruct-8k"
    result = apply_defaults(model_id, {}, [])
    assert result == {}
9+
10+
def test_default_model_params_nomatch():
    """Defaults whose matcher does not match the model id are ignored."""
    model_id = "replicate/ibm-granite/granite-20b-code-instruct-8k"
    unrelated_defaults = [{"dummy": {"foo": "bar"}}]
    result = apply_defaults(model_id, {}, unrelated_defaults)
    assert result == {}
17+
18+
def test_default_model_params_exact_match():
    """A matcher equal to the model id applies; explicit values are kept."""
    model_id = "replicate/ibm-granite/granite-20b-code-instruct-8k"
    explicit = {"foo": "baz"}
    defaults = [{model_id: {"foo": "bar", "max_tokens": 9999}}]

    result = apply_defaults(model_id, explicit, defaults)

    # "foo" was set explicitly, so only "max_tokens" comes from the defaults.
    assert result == {"foo": "baz", "max_tokens": 9999}
35+
36+
def test_default_model_params_partial_matches():
    """Several globs may match; later matches override earlier ones."""
    granite_defaults = {"foo": "bar", "max_tokens": 9999}
    instruct_8k_defaults = {"fruit": "banana", "max_tokens": 777}
    destruct_401k_defaults = {"vegetable": "carrot", "max_tokens": 888}
    defaults = [
        {"*granite*": granite_defaults},
        {"*instruct-8k*": instruct_8k_defaults},
        {"*destruct-401k*": destruct_401k_defaults},
    ]

    result = apply_defaults(
        "replicate/ibm-granite/granite-20b-code-instruct-8k",
        {"foo": "baz"},
        defaults,
    )

    # The first two globs match; the third does not, so neither "vegetable"
    # nor its "max_tokens" value shows up.
    assert result == {"foo": "baz", "fruit": "banana", "max_tokens": 777}
67+
68+
def test_default_model_params():
    """Check the built-in defaults against two replicate-hosted granite models."""
    builtin_defaults = get_default_model_parameters()

    # No defaults for this model
    untouched = apply_defaults(
        "replicate/ibm-granite/granite-20b-code-instruct-8k",
        {},
        builtin_defaults,
    )
    assert untouched == {}

    # Granite-3.0 defaults for this model
    expected_roles = {
        "system": {
            "pre_message": "<|start_of_role|>system<|end_of_role|>",
            "post_message": "<|end_of_text|>",
        },
        "user": {
            "pre_message": "<|start_of_role|>user<|end_of_role|>",
            "post_message": "<|end_of_text|>",
        },
        "assistant": {
            "pre_message": "<|start_of_role|>assistant<|end_of_role|>",
            "post_message": "<|end_of_text|>",
        },
        "available_tools": {
            "pre_message": "<|start_of_role|>available_tools<|end_of_role|>",
            "post_message": "<|end_of_text|>",
        },
        "tool_response": {
            "pre_message": "<|start_of_role|>tool_response<|end_of_role|>",
            "post_message": "<|end_of_text|>",
        },
    }
    expected = {
        "temperature": 0,
        "roles": expected_roles,
        "final_prompt_value": "<|start_of_role|>assistant<|end_of_role|>",
    }
    granite3 = apply_defaults(
        "replicate/ibm-granite/granite-3.0-8b-instruct",
        {},
        builtin_defaults,
    )
    assert granite3 == expected
107+
108+
def test_default_not_granite_20b_code_instruct_r1_1():
    """granite-20b-code-instruct-r1.1 must never receive any defaults."""
    builtin_defaults = get_default_model_parameters()
    result = apply_defaults(
        "replicate/ibm-granite/granite-20b-code-instruct-r1.1",
        {},
        builtin_defaults,
    )
    assert result == {}

0 commit comments

Comments
 (0)