Commit 9d6a625

Add 'hallucinations' filter oobabooga#326

This breaks the API, since a new parameter (encoder_repetition_penalty) has been added. Updating existing clients should be a one-line fix; see api-example.py.

1 parent 128d18e commit 9d6a625

File tree

5 files changed: +25 −18

api-example-stream.py
api-example.py
modules/chat.py
modules/text_generation.py
server.py

api-example-stream.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@ async def run(context):
         'top_p': 0.9,
         'typical_p': 1,
         'repetition_penalty': 1.05,
+        'encoder_repetition_penalty': 1.0,
         'top_k': 0,
         'min_length': 0,
         'no_repeat_ngram_size': 0,
@@ -59,6 +60,7 @@ async def run(context):
         params['top_p'],
         params['typical_p'],
         params['repetition_penalty'],
+        params['encoder_repetition_penalty'],
         params['top_k'],
         params['min_length'],
         params['no_repeat_ngram_size'],

api-example.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
     'top_p': 0.9,
     'typical_p': 1,
     'repetition_penalty': 1.05,
+    'encoder_repetition_penalty': 1.0,
     'top_k': 0,
     'min_length': 0,
     'no_repeat_ngram_size': 0,
@@ -45,6 +46,7 @@
     params['top_p'],
     params['typical_p'],
     params['repetition_penalty'],
+    params['encoder_repetition_penalty'],
     params['top_k'],
     params['min_length'],
     params['no_repeat_ngram_size'],
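The "one-line fix" from the commit message amounts to mirroring the two added lines above in any client built on these example scripts: one new key in the params dictionary and one new value in the positional list, placed exactly between repetition_penalty and top_k. A minimal sketch of the updated fragment (surrounding keys elided; everything else about the request is unchanged):

params = {
    # ...earlier keys unchanged...
    'typical_p': 1,
    'repetition_penalty': 1.05,
    'encoder_repetition_penalty': 1.0,  # new parameter; 1.0 means no penalty
    'top_k': 0,
    # ...later keys unchanged...
}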

modules/chat.py

Lines changed: 8 additions & 8 deletions
@@ -97,7 +97,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
 def stop_everything_event():
     shared.stop_everything = True

-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
     shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
@@ -133,7 +133,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     # Generate
     reply = ''
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
+        for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):

             # Extracting the reply
             reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
@@ -160,7 +160,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical

     yield shared.history['visible']

-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     eos_token = '\n' if check else None

     if 'pygmalion' in shared.model_name.lower():
@@ -172,26 +172,26 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     # Yield *Is typing...*
     yield shared.processing_message
     for i in range(chat_generation_attempts):
-        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
+        for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
             reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
             yield reply
             if next_character_found:
                 break
     yield reply

-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
-    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+    for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
         yield generate_chat_html(_history, name1, name2, shared.character)

-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
         yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
     else:
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()
         # Yield '*Is typing...*'
         yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
-        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
+        for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
             if shared.args.cai_chat:
                 shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
             else:
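The eight paired changes above exist because these wrappers thread the sampling settings positionally rather than through keyword arguments, so every signature and every call site must change in lockstep. A stale caller does not necessarily fail loudly; arguments after repetition_penalty can silently shift by one position. A toy illustration of that hazard (hypothetical function, not from this commit):

# Hypothetical toy function illustrating positional-argument drift.
def generate(repetition_penalty, encoder_repetition_penalty, top_k=0):
    print(repetition_penalty, encoder_repetition_penalty, top_k)

generate(1.05, 1.0, 40)  # updated caller: values land where intended
generate(1.05, 40)       # stale caller (old argument list): 40 silently becomes encoder_repetition_penalty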

modules/text_generation.py

Lines changed: 2 additions & 1 deletion
@@ -89,7 +89,7 @@ def clear_torch_cache():
     if not shared.args.cpu:
         torch.cuda.empty_cache()

-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
+def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
     clear_torch_cache()
     t0 = time.time()

@@ -143,6 +143,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         "top_p": top_p,
         "typical_p": typical_p,
         "repetition_penalty": repetition_penalty,
+        "encoder_repetition_penalty": encoder_repetition_penalty,
         "top_k": top_k,
         "min_length": min_length if shared.args.no_stream else 0,
         "no_repeat_ngram_size": no_repeat_ngram_size,
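Here the new value joins the keyword arguments that generate_reply builds for the underlying Hugging Face transformers generate() call. In transformers, encoder_repetition_penalty penalizes tokens that do not appear in the original input, which is why the commit title calls it a 'hallucinations' filter. A minimal standalone sketch, assuming a transformers version recent enough to support the parameter (the model choice is arbitrary, for illustration only):

from transformers import AutoModelForCausalLM, AutoTokenizer

# 'gpt2' is an arbitrary small model, used only to demonstrate the parameter.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

input_ids = tokenizer('The pyramids of Giza were built', return_tensors='pt').input_ids
output = model.generate(
    input_ids,
    max_new_tokens=40,
    do_sample=True,
    repetition_penalty=1.05,
    # Values above 1.0 bias generation toward tokens already in the prompt;
    # 1.0 (the default added in this commit) disables the filter.
    encoder_repetition_penalty=1.3,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))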

server.py

Lines changed: 11 additions & 9 deletions
@@ -66,6 +66,7 @@ def load_preset_values(preset_menu, return_dict=False):
         'top_p': 1,
         'typical_p': 1,
         'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
         'top_k': 50,
         'num_beams': 1,
         'penalty_alpha': 0,
@@ -86,7 +87,7 @@ def load_preset_values(preset_menu, return_dict=False):
     if return_dict:
         return generate_params
     else:
-        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']

 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
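Seeding the defaults dictionary keeps old presets compatible: a preset saved before this commit never mentions the new key, so it inherits the neutral value of 1. A small sketch of that fallback behavior (the preset values here are hypothetical; only the default-then-override merge matters):

# Hypothetical preset contents; illustrates why old presets keep working.
defaults = {'repetition_penalty': 1, 'encoder_repetition_penalty': 1}
old_preset = {'repetition_penalty': 1.1}  # written before the new key existed

generate_params = {**defaults, **old_preset}
assert generate_params['encoder_repetition_penalty'] == 1  # neutral default survives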
@@ -117,14 +118,15 @@ def create_settings_menus(default_preset):
     with gr.Row():
         with gr.Column():
             shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
-            shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
-            shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
             shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
-        with gr.Column():
-            shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+            shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
             shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+        with gr.Column():
+            shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
+            shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
             shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
             shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
+            shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')

     gr.Markdown('Contrastive search:')
     shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
@@ -147,7 +149,7 @@ def create_settings_menus(default_preset):
         shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])

     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
@@ -262,7 +264,7 @@ def create_settings_menus(default_preset):
             shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
         create_settings_menus(default_preset)

-    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
     if shared.args.extensions is not None:
         with gr.Tab('Extensions'):
             extensions_module.create_extensions_block()
@@ -329,7 +331,7 @@ def create_settings_menus(default_preset):
     if shared.args.extensions is not None:
         extensions_module.create_extensions_block()

-    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
     output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
     gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
     gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
@@ -361,7 +363,7 @@ def create_settings_menus(default_preset):
         with gr.Tab('HTML'):
             shared.gradio['html'] = gr.HTML()

-    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+    shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
     output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
     gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
     gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
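These shared.input_params lists are the Gradio side of the same positional contract: Gradio passes component values to the handler in list order, so the key order here must match the parameter order of generate_reply and the chat wrappers exactly. A reduced sketch of the invariant worth checking whenever a parameter is added (the stub names are illustrative, not from this commit):

# Reduced illustration: the input key list must mirror the handler's signature.
import inspect

def generate_reply_stub(textbox, repetition_penalty, encoder_repetition_penalty, top_k):
    pass

input_keys = ['textbox', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k']
assert input_keys == list(inspect.signature(generate_reply_stub).parameters)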
