@@ -122,7 +122,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
     output = input_ids[0]
-    cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
+    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))
@@ -132,45 +132,48 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         t = encode(stopping_string, 0, add_special_tokens=False)
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))

+    generate_params = {}
     if not shared.args.flexgen:
-        generate_params = [
-            f"max_new_tokens=max_new_tokens",
-            f"eos_token_id={eos_token_ids}",
-            f"stopping_criteria=stopping_criteria_list",
-            f"do_sample={do_sample}",
-            f"temperature={temperature}",
-            f"top_p={top_p}",
-            f"typical_p={typical_p}",
-            f"repetition_penalty={repetition_penalty}",
-            f"top_k={top_k}",
-            f"min_length={min_length if shared.args.no_stream else 0}",
-            f"no_repeat_ngram_size={no_repeat_ngram_size}",
-            f"num_beams={num_beams}",
-            f"penalty_alpha={penalty_alpha}",
-            f"length_penalty={length_penalty}",
-            f"early_stopping={early_stopping}",
-        ]
+        generate_params.update({
+            "max_new_tokens": max_new_tokens,
+            "eos_token_id": eos_token_ids,
+            "stopping_criteria": stopping_criteria_list,
+            "do_sample": do_sample,
+            "temperature": temperature,
+            "top_p": top_p,
+            "typical_p": typical_p,
+            "repetition_penalty": repetition_penalty,
+            "top_k": top_k,
+            "min_length": min_length if shared.args.no_stream else 0,
+            "no_repeat_ngram_size": no_repeat_ngram_size,
+            "num_beams": num_beams,
+            "penalty_alpha": penalty_alpha,
+            "length_penalty": length_penalty,
+            "early_stopping": early_stopping,
+        })
     else:
-        generate_params = [
-            f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
-            f"do_sample={do_sample}",
-            f"temperature={temperature}",
-            f"stop={eos_token_ids[-1]}",
-        ]
+        generate_params.update({
+            "max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
+            "do_sample": do_sample,
+            "temperature": temperature,
+            "stop": eos_token_ids[-1],
+        })
     if shared.args.deepspeed:
-        generate_params.append("synced_gpus=True")
+        generate_params.update({"synced_gpus": True})
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-        generate_params.insert(0, "inputs_embeds=inputs_embeds")
-        generate_params.insert(0, "inputs=filler_input_ids")
+        generate_params.update({"inputs_embeds": inputs_embeds})
+        generate_params.update({"inputs": filler_input_ids})
     else:
-        generate_params.insert(0, "inputs=input_ids")
+        generate_params.update({"inputs": input_ids})

     try:
         # Generate the entire reply at once.
         if shared.args.no_stream:
             with torch.no_grad():
-                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
+                output = shared.model.generate(**generate_params)[0]
+                if cuda:
+                    output = output.cuda()
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

@@ -194,7 +197,7 @@ def generate_with_streaming(**kwargs):
                 return Iteratorize(generate_with_callback, kwargs, callback=None)

             yield formatted_outputs(original_question, shared.model_name)
-            with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
+            with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
@@ -214,7 +217,7 @@ def generate_with_streaming(**kwargs):
             for i in range(max_new_tokens // 8 + 1):
                 clear_torch_cache()
                 with torch.no_grad():
-                    output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
+                    output = shared.model.generate(**generate_params)[0]
                 if shared.soft_prompt:
                     output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
                 reply = decode(output)
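
For reference, here is a minimal, standalone sketch of the pattern this diff switches to: collect the generation arguments in a plain dict and unpack it with `**` at the call site, instead of joining `"key=value"` strings and running them through `eval()`. The `generate` stub below is hypothetical; it only echoes its arguments and stands in for `shared.model.generate`.

```python
# Hypothetical stand-in for shared.model.generate(); it just echoes what it receives.
def generate(inputs=None, max_new_tokens=200, do_sample=True, temperature=1.0, **extra):
    return {"inputs": inputs, "max_new_tokens": max_new_tokens,
            "do_sample": do_sample, "temperature": temperature, **extra}

generate_params = {}
generate_params.update({
    "max_new_tokens": 8,
    "do_sample": True,
    "temperature": 0.7,
})
generate_params.update({"inputs": [1, 2, 3]})

# Values stay real Python objects (lists, tensors, StoppingCriteriaList, ...),
# so nothing has to be round-tripped through a string and re-parsed by eval().
print(generate(**generate_params))
```

Besides readability, this removes the `eval()` call on an interpolated string, so parameter values are no longer re-executed as Python source at generation time.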