|
| 1 | +import fnmatch |
1 | 2 | import json
|
2 | 3 | from typing import Any, Sequence
|
3 | 4 |
|
4 |
| -from .pdl_ast import ContributeTarget, ContributeValue, FunctionBlock, Message, Messages |
| 5 | +from .pdl_ast import ContributeTarget, ContributeValue, FunctionBlock, Message, Messages, get_sampling_defaults |
5 | 6 |
|
6 | 7 |
|
7 | 8 | def stringify(result):
|
@@ -83,3 +84,67 @@ def remove_none_values_from_message(message: Any) -> dict[str, Any]:
|
83 | 84 | else:
|
84 | 85 | ret[key] = value
|
85 | 86 | return ret
|
| 87 | + |
| 88 | +def apply_defaults(model_id: str, params: dict[str, Any], all_model_defaults: list[dict[str, dict[str, Any]]]) -> dict[str, Any]: |
| 89 | + # Never apply defaults to granite-20b-code-instruct-r1.1 |
| 90 | + if "granite-20b-code-instruct-r1.1" in model_id: |
| 91 | + return params |
| 92 | + |
| 93 | + parameters = apply_raw_defaults(model_id, params, all_model_defaults) |
| 94 | + |
| 95 | + if "decoding_method" in parameters and parameters["decoding_method"] == "sample": |
| 96 | + parameters = apply_raw_defaults(model_id, parameters, get_sampling_defaults()) |
| 97 | + |
| 98 | + return parameters |
| 99 | + |
| 100 | +def apply_raw_defaults(model_id: str, params: dict[str, Any], model_defaults: list[dict[str, dict[str, Any]]]) -> dict[str, Any]: |
| 101 | + """Apply defaults to params based on a list of model defaults |
| 102 | +
|
| 103 | + Args: |
| 104 | + model_id: A PDL model ID |
| 105 | + params: The explicit parameters set by in PDL |
| 106 | + model_defaults: A list of dicts, where the keys are globs for model id, and the value is a dict of defaults |
| 107 | +
|
| 108 | + Returns: |
| 109 | + The parameters to send to the LLM |
| 110 | + """ |
| 111 | + |
| 112 | + assert isinstance(model_id, str) |
| 113 | + assert isinstance(params, dict) |
| 114 | + assert isinstance(model_defaults, list) |
| 115 | + |
| 116 | + # Construct defaults for this model. If more than one set of default |
| 117 | + # applies, the last seen default "wins". |
| 118 | + default_union = {} |
| 119 | + for model_default in model_defaults: |
| 120 | + assert isinstance(model_default, dict) |
| 121 | + for model_glob, glob_defaults in model_default.items(): |
| 122 | + if not isinstance(glob_defaults, dict): |
| 123 | + raise ValueError(f"invalid default type {type(glob_defaults)} for model matcher {model_glob}") |
| 124 | + assert isinstance(glob_defaults, dict) |
| 125 | + if fnmatch.fnmatchcase(model_id, model_glob): |
| 126 | + print(f"model {model_id} matches {model_glob}") |
| 127 | + for k, v in glob_defaults.items(): |
| 128 | + default_union[k] = v |
| 129 | + |
| 130 | + # Apply final list of defaults to explicit parameters |
| 131 | + retval = dict(params) |
| 132 | + for k, v in default_union.items(): |
| 133 | + if k not in retval or retval[k] is None: |
| 134 | + retval[k] = v |
| 135 | + return retval |
| 136 | + |
| 137 | +def validate_scope(scope: dict): |
| 138 | + """Throw an exception if any key in scope is invalid""" |
| 139 | + validate_pdl_model_defaults(scope["pdl_model_default_parameters"]) |
| 140 | + |
| 141 | +def validate_pdl_model_defaults(model_defaults: list[dict[str, dict[str, Any]]]): |
| 142 | + """Throw an exception if the model_defaults is not in expected format""" |
| 143 | + |
| 144 | + errors = False |
| 145 | + for model_default in model_defaults: |
| 146 | + assert isinstance(model_default, dict) |
| 147 | + for model_glob, glob_defaults in model_default.items(): |
| 148 | + if not isinstance(glob_defaults, dict): |
| 149 | + raise ValueError(f"invalid defaults {glob_defaults} for model matcher {model_glob}") |
| 150 | + assert isinstance(glob_defaults, dict) |
0 commit comments