Commit 0e53a4c

Add the saved token report, auto flush data (#401)

Signed-off-by: SimFG <[email protected]>
1 parent 9265b1a commit 0e53a4c

12 files changed: 263 additions, 56 deletions

gptcache/__init__.py (0 additions, 1 deletion)

@@ -1,7 +1,6 @@
 """gptcache version"""
 __version__ = "0.1.28"
 
-from gptcache.client import Client
 from gptcache.config import Config
 from gptcache.core import Cache
 from gptcache.core import cache

gptcache/adapter/adapter.py (48 additions, 15 deletions)

@@ -36,7 +36,11 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
     else:  # temperature <= 0
         cache_skip = kwargs.pop("cache_skip", False)
     cache_factor = kwargs.pop("cache_factor", 1.0)
-    pre_embedding_res = chat_cache.pre_embedding_func(
+    pre_embedding_res = time_cal(
+        chat_cache.pre_embedding_func,
+        func_name="pre_process",
+        report_func=chat_cache.report.pre,
+    )(
         kwargs,
         extra_param=context.get("pre_embedding_func", None),
         prompts=chat_cache.config.prompts,
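Every stage instrumented in this diff follows the same pattern: the target callable goes into time_cal together with a func_name label and the matching Report method, and the returned wrapper is called with the original arguments. The shipped helper lives in gptcache.utils.time; the following is only a minimal sketch of the behavior the diff relies on, not the library's actual source:

import time


def time_cal(func, func_name=None, report_func=None):
    """Wrap func so each call is timed and the elapsed seconds get reported."""

    def inner(*args, **kwargs):
        start = time.time()
        res = func(*args, **kwargs)
        if report_func is not None:
            report_func(time.time() - start)  # e.g. chat_cache.report.pre
        return res

    return inner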
@@ -81,7 +85,11 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
             else rank_threshold
         )
         for cache_data in cache_data_list:
-            ret = chat_cache.data_manager.get_scalar_data(
+            ret = time_cal(
+                chat_cache.data_manager.get_scalar_data,
+                func_name="get_data",
+                report_func=chat_cache.report.data,
+            )(
                 cache_data,
                 extra_param=context.get("get_scalar_data", None),
                 session=session,
@@ -112,7 +120,11 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
                 "search_result": cache_data,
                 "embedding": ret.embedding_data,
             }
-            rank = chat_cache.similarity_evaluation.evaluation(
+            rank = time_cal(
+                chat_cache.similarity_evaluation.evaluation,
+                func_name="evaluation",
+                report_func=chat_cache.report.evaluation,
+            )(
                 eval_query_data,
                 eval_cache_data,
                 extra_param=context.get("evaluation_func", None),
@@ -129,16 +141,25 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
         cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
         answers_dict = dict((d[1], d[2]) for d in cache_answers)
     if len(cache_answers) != 0:
-        if chat_cache.post_process_messages_func is temperature_softmax:
-            return_message = chat_cache.post_process_messages_func(
-                messages=[t[1] for t in cache_answers],
-                scores=[t[0] for t in cache_answers],
-                temperature=temperature,
-            )
-        else:
-            return_message = chat_cache.post_process_messages_func(
-                [t[1] for t in cache_answers]
-            )
+
+        def post_process():
+            if chat_cache.post_process_messages_func is temperature_softmax:
+                return_message = chat_cache.post_process_messages_func(
+                    messages=[t[1] for t in cache_answers],
+                    scores=[t[0] for t in cache_answers],
+                    temperature=temperature,
+                )
+            else:
+                return_message = chat_cache.post_process_messages_func(
+                    [t[1] for t in cache_answers]
+                )
+            return return_message
+
+        return_message = time_cal(
+            post_process,
+            func_name="post_process",
+            report_func=chat_cache.report.post,
+        )()
         chat_cache.report.hint_cache()
         if session:
             chat_cache.data_manager.add_session(
@@ -156,7 +177,9 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
         llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs
     )
 else:
-    llm_data = llm_handler(*args, **kwargs)
+    llm_data = time_cal(
+        llm_handler, func_name="llm_request", report_func=chat_cache.report.llm
+    )(*args, **kwargs)
 
 if cache_enable:
     try:
@@ -166,13 +189,23 @@ def update_cache_func(handled_llm_data, question=None):
                 question = pre_store_data
             else:
                 question.content = pre_store_data
-            chat_cache.data_manager.save(
+            time_cal(
+                chat_cache.data_manager.save,
+                func_name="save",
+                report_func=chat_cache.report.save,
+            )(
                 question,
                 handled_llm_data,
                 embedding_data,
                 extra_param=context.get("save_func", None),
                 session=session,
             )
+            if (
+                chat_cache.report.op_save.count > 0
+                and chat_cache.report.op_save.count % chat_cache.config.auto_flush
+                == 0
+            ):
+                chat_cache.flush()
 
         llm_data = update_cache_callback(
             llm_data, update_cache_func, *args, **kwargs
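The new block after the save call is the auto flush added by this commit: once the running save counter is a positive multiple of config.auto_flush (default 20), the whole cache is flushed, so the check fires on the 20th, 40th, 60th save and so on. A standalone illustration of the arithmetic, with no GPTCache objects involved:

AUTO_FLUSH = 20  # default from gptcache/config.py

save_count = 0  # stands in for chat_cache.report.op_save.count
for _ in range(45):
    save_count += 1  # Report.save() increments op_save.count once per save
    if save_count > 0 and save_count % AUTO_FLUSH == 0:
        print(f"flush after save #{save_count}")  # prints at 20 and 40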

gptcache/adapter/api.py (3 additions, 0 deletions)

@@ -17,6 +17,7 @@
     Cohere,
     Rwkv,
     PaddleNLP,
+    UForm,
 )
 from gptcache.embedding.base import BaseEmbedding
 from gptcache.manager import manager_factory
@@ -276,6 +277,8 @@ def _get_model(model_src, model_config=None):
         return Rwkv(**model_config)
     if model_src == "paddlenlp":
         return PaddleNLP(**model_config)
+    if model_src == "uform":
+        return UForm(**model_config)
 
 
 def _get_eval(strategy, kws=None):
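With the new branch in _get_model, a cache configuration can now select the UForm embedding via the model source string "uform". A hedged usage sketch: it assumes the uform extra is installed, that the default model weights can be downloaded, and that to_embeddings/dimension follow the BaseEmbedding interface imported above:

from gptcache.embedding import UForm

encoder = UForm()  # may download the default model on first use
vec = encoder.to_embeddings("what is github?")
print(encoder.dimension, len(vec))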

gptcache/adapter/openai.py (33 additions, 4 deletions)

@@ -5,6 +5,7 @@
 from io import BytesIO
 from typing import Iterator, Any, List
 
+from gptcache import cache
 from gptcache.adapter.adapter import adapt
 from gptcache.adapter.base import BaseCacheLLM
 from gptcache.manager.scalar_data.base import Answer, DataType
@@ -18,6 +19,7 @@
     get_image_from_openai_url,
     get_audio_text_from_openai_answer,
 )
+from gptcache.utils.token import token_counter
 
 import_openai()
 
@@ -80,10 +82,19 @@ def hook_openai_data(it):
 
     @classmethod
     def create(cls, *args, **kwargs):
+        chat_cache = kwargs.get("cache_obj", cache)
+        enable_token_counter = chat_cache.config.enable_token_counter
+
         def cache_data_convert(cache_data):
+            if enable_token_counter:
+                input_token = _num_tokens_from_messages(kwargs.get("messages"))
+                output_token = token_counter(cache_data)
+                saved_token = [input_token, output_token]
+            else:
+                saved_token = [0, 0]
             if kwargs.get("stream", False):
-                return _construct_stream_resp_from_cache(cache_data)
-            return _construct_resp_from_cache(cache_data)
+                return _construct_stream_resp_from_cache(cache_data, saved_token)
+            return _construct_resp_from_cache(cache_data, saved_token)
 
         kwargs = cls.fill_base_args(**kwargs)
         return adapt(
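Because cache_data_convert now threads saved_token into every response it builds, callers can read the savings straight off a cache hit. An illustrative call, assuming a cache has already been initialized elsewhere (for example with gptcache.adapter.api.init_similar_cache):

from gptcache.adapter import openai

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is github?"}],
)
if response.get("gptcache"):  # present only on cache hits
    input_tokens, output_tokens = response["saved_token"]
    print(f"saved {input_tokens + output_tokens} tokens")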
@@ -346,9 +357,10 @@ def create(cls, *args, **kwargs):
         return res
 
 
-def _construct_resp_from_cache(return_message):
+def _construct_resp_from_cache(return_message, saved_token):
     return {
         "gptcache": True,
+        "saved_token": saved_token,
         "choices": [
             {
                 "message": {"role": "assistant", "content": return_message},
@@ -362,7 +374,7 @@ def _construct_resp_from_cache(return_message):
     }
 
 
-def _construct_stream_resp_from_cache(return_message):
+def _construct_stream_resp_from_cache(return_message, saved_token):
     created = int(time.time())
     return [
         {
@@ -388,6 +400,7 @@ def _construct_stream_resp_from_cache(return_message):
             "choices": [{"delta": {}, "finish_reason": "stop", "index": 0}],
             "created": created,
             "object": "chat.completion.chunk",
+            "saved_token": saved_token,
         },
     ]
 
@@ -447,3 +460,19 @@ def _construct_audio_text_from_cache(return_text):
         "gptcache": True,
         "text": return_text,
     }
+
+
+def _num_tokens_from_messages(messages):
+    """Returns the number of tokens used by a list of messages."""
+    tokens_per_message = 3
+    tokens_per_name = 1
+
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for key, value in message.items():
+            num_tokens += token_counter(value)
+            if key == "name":
+                num_tokens += tokens_per_name
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
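The constants mirror OpenAI's published accounting for gpt-3.5-turbo-style chat models: 3 overhead tokens per message, 1 extra when a name field is present, and 3 tokens priming the assistant's reply. Worked through for a one-message request, with assumed per-string token counts for illustration:

messages = [{"role": "user", "content": "what is github?"}]

num_tokens = 0
num_tokens += 3  # tokens_per_message overhead for the single message
num_tokens += 1  # token_counter("user"), assumed to be 1 token
num_tokens += 4  # token_counter("what is github?"), assumed to be 4 tokens
num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
print(num_tokens)  # 11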

gptcache/config.py (8 additions, 0 deletions)

@@ -15,6 +15,10 @@ class Config:
     :type prompts: Optional[List[str]]
     :param template: optional, if the request content will remove the template string and only keep the parameter value in the template
     :type template: Optional[str]
+    :param auto_flush: the cache will be flushed automatically every time `auto_flush` pieces of data are added, default to 20
+    :type auto_flush: int
+    :param enable_token_counter: enable token counter, default to True
+    :type enable_token_counter: bool
 
     Example:
         .. code-block:: python
@@ -30,6 +34,8 @@ def __init__(
         similarity_threshold: float = 0.8,
         prompts: Optional[List[str]] = None,
         template: Optional[str] = None,
+        auto_flush: int = 20,
+        enable_token_counter: bool = True,
     ):
         if similarity_threshold < 0 or similarity_threshold > 1:
             raise CacheError(
@@ -39,3 +45,5 @@ def __init__(
         self.similarity_threshold = similarity_threshold
         self.prompts = prompts
         self.template = template
+        self.auto_flush = auto_flush
+        self.enable_token_counter = enable_token_counter
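Both knobs are ordinary constructor arguments, so they can be set when the cache is built. A minimal sketch, with values chosen purely for illustration:

from gptcache import Config

config = Config(
    auto_flush=10,              # flush the data manager after every 10 saves
    enable_token_counter=True,  # attach saved_token to cache-hit responses
)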

gptcache/report.py (100 additions, 20 deletions)

@@ -2,46 +2,126 @@ class Report:
     """Get GPTCache report including time and counts for different operations."""
 
     def __init__(self):
-        self.embedding_all_time = 0
-        self.embedding_count = 0
-        self.search_all_time = 0
-        self.search_count = 0
+        self.op_pre = OpCounter()
+        self.op_embedding = OpCounter()
+        self.op_search = OpCounter()
+        self.op_data = OpCounter()
+        self.op_evaluation = OpCounter()
+        self.op_post = OpCounter()
+        self.op_llm = OpCounter()
+        self.op_save = OpCounter()
         self.hint_cache_count = 0
 
+    def pre(self, delta_time):
+        """Pre-process counts and time.
+
+        :param delta_time: additional runtime.
+        """
+        self.op_pre.total_time += delta_time
+        self.op_pre.count += 1
+
     def embedding(self, delta_time):
         """Embedding counts and time.
 
         :param delta_time: additional runtime.
         """
-        self.embedding_all_time += delta_time
-        self.embedding_count += 1
+        self.op_embedding.total_time += delta_time
+        self.op_embedding.count += 1
 
     def search(self, delta_time):
         """Search counts and time.
 
         :param delta_time: additional runtime.
         """
-        self.search_all_time += delta_time
-        self.search_count += 1
+        self.op_search.total_time += delta_time
+        self.op_search.count += 1
+
+    def data(self, delta_time):
+        """Get data counts and time.
+
+        :param delta_time: additional runtime.
+        """
+
+        self.op_data.total_time += delta_time
+        self.op_data.count += 1
+
+    def evaluation(self, delta_time):
+        """Evaluation counts and time.
+
+        :param delta_time: additional runtime.
+        """
+        self.op_evaluation.total_time += delta_time
+        self.op_evaluation.count += 1
+
+    def post(self, delta_time):
+        """Post-process counts and time.
+
+        :param delta_time: additional runtime.
+        """
+        self.op_post.total_time += delta_time
+        self.op_post.count += 1
+
+    def llm(self, delta_time):
+        """LLM counts and time.
+
+        :param delta_time: additional runtime.
+        """
+        self.op_llm.total_time += delta_time
+        self.op_llm.count += 1
+
+    def save(self, delta_time):
+        """Save counts and time.
+
+        :param delta_time: additional runtime.
+        """
+        self.op_save.total_time += delta_time
+        self.op_save.count += 1
+
+    def average_pre_time(self):
+        """Average pre-process time."""
+        return self.op_pre.average()
 
     def average_embedding_time(self):
         """Average embedding time."""
-        return round(
-            self.embedding_all_time / self.embedding_count
-            if self.embedding_count != 0
-            else 0,
-            4,
-        )
+        return self.op_embedding.average()
 
     def average_search_time(self):
         """Average search time."""
-        return round(
-            self.search_all_time / self.search_count
-            if self.embedding_count != 0
-            else 0,
-            4,
-        )
+        return self.op_search.average()
+
+    def average_data_time(self):
+        """Average data time."""
+        return self.op_data.average()
+
+    def average_evaluation_time(self):
+        """Average evaluation time."""
+        return self.op_evaluation.average()
+
+    def average_post_time(self):
+        """Average post-process time."""
+        return self.op_post.average()
+
+    def average_llm_time(self):
+        """Average LLM time."""
+        return self.op_llm.average()
+
+    def average_save_time(self):
+        """Average save time."""
+        return self.op_save.average()
 
     def hint_cache(self):
         """hint cache count."""
         self.hint_cache_count += 1
+
+
+class OpCounter:
+    """Operation counter."""
+
+    count = 0
+    """Operation count."""
+    total_time = 0
+    """Total time."""
+
+    def average(self):
+        """Average time."""
+        return round(self.total_time / self.count, 4) if self.count != 0 else 0
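In normal operation these methods are invoked only through time_cal's report_func hook, but the reworked counters can also be exercised directly, which shows the rounding behavior of OpCounter.average():

from gptcache.report import Report

report = Report()
report.llm(1.2)    # two simulated LLM calls
report.llm(0.8)
report.save(0.05)  # and one save

print(report.average_llm_time())  # (1.2 + 0.8) / 2 = 1.0
print(report.op_save.count)       # 1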
