
Commit 03a2787

Add the pre-function of handling long prompt and Update the context processor doc (#395)
Signed-off-by: SimFG <[email protected]>
1 parent 873fca7 commit 03a2787

File tree

8 files changed: 188 additions, 10 deletions


docs/conf.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@
 intersphinx_mapping = {
     "torch": ("https://pytorch.org/docs/stable/", None),
     "numpy": ("https://numpy.org/devdocs/", None),
-    "python": ("https://docs.python.org/3", None),
+    "python": ("https://docs.python.org/3.8/", None),
 }

 autodoc_member_order = "bysource"

gptcache/adapter/adapter.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwarg
         kwargs,
         extra_param=context.get("pre_embedding_func", None),
         prompts=chat_cache.config.prompts,
+        cache_config=chat_cache.config,
     )
     if isinstance(pre_embedding_res, tuple):
         pre_store_data = pre_embedding_res[0]
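With this change, adapt() forwards the whole Config object to the pre-embedding function as a cache_config keyword argument, so any pre-function can read cache-level settings from its **params. A minimal sketch of a custom pre-function picking it up (the function name and the template handling are illustrative, not part of this commit):

    from typing import Any, Dict

    def my_pre_func(data: Dict[str, Any], **params: Dict[str, Any]) -> Any:
        # cache_config is the Config instance forwarded by adapt()
        cache_config = params.get("cache_config", None)
        content = data.get("messages")[-1]["content"]
        if cache_config and cache_config.template:
            # e.g. strip the configured template and keep only the parameter values
            pass
        return content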

gptcache/config.py

Lines changed: 5 additions & 1 deletion
@@ -13,6 +13,8 @@ class Config:
     :type similarity_threshold: float
     :param prompts: optional, if the request content will remove the prompt string when the request contains the prompt list
     :type prompts: Optional[List[str]]
+    :param template: optional, if the request content will remove the template string and only keep the parameter value in the template
+    :type template: Optional[str]

     Example:
         .. code-block:: python
@@ -26,7 +28,8 @@ def __init__(
         self,
         log_time_func: Optional[Callable[[str, float], None]] = None,
         similarity_threshold: float = 0.8,
-        prompts: Optional[List[str]] = None
+        prompts: Optional[List[str]] = None,
+        template: Optional[str] = None,
     ):
         if similarity_threshold < 0 or similarity_threshold > 1:
             raise CacheError(
@@ -35,3 +38,4 @@ def __init__(
         self.log_time_func = log_time_func
         self.similarity_threshold = similarity_threshold
         self.prompts = prompts
+        self.template = template
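The new template field is meant to be set when initializing the cache together with the matching pre-function. A minimal sketch, assuming the standard cache.init entry point (the template string itself is only an illustration):

    from gptcache import cache, Config
    from gptcache.processor.pre import last_content_without_template

    cache.init(
        pre_embedding_func=last_content_without_template,
        config=Config(template="tell me a joke about {subject}"),
    )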

gptcache/processor/context/concat_context.py

Lines changed: 20 additions & 1 deletion
@@ -4,7 +4,26 @@


 class ConcatContextProcess(ContextProcess):
-    """A concat context processor simply concat the context
+    """A concat context processor simply concat the context.
+    Generally used with rwkv embedding, because rwkv can input almost infinitely long
+
+    Example:
+        .. code-block:: python
+
+            from gptcache.manager import manager_factory
+            from gptcache.processor.context.concat_context import ConcatContextProcess
+
+            context_process = ConcatContextProcess()
+            rwkv_embedding = Rwkv()
+            data_manager = manager_factory(
+                "sqlite,faiss",
+                vector_params={"dimension": rwkv_embedding.dimension},
+            )
+            cache.init(
+                pre_embedding_func=context_process.pre_process,
+                embedding_func=rwkv_embedding.to_embeddings,
+                data_manager=data_manager,
+            )
     """

     content: str = ""
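Note that the docstring example above assumes the global cache object and the rwkv embedding are already in scope, roughly (the exact import path is an assumption, not shown in the diff):

    from gptcache import cache
    from gptcache.embedding import Rwkv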

gptcache/processor/context/selective_context.py

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,13 @@ class SelectiveContextProcess(ContextProcess):

     more details: https://github.com/liyucheng09/Selective_Context

+    Example:
+        .. code-block:: python
+
+            from gptcache.processor.context.selective_context import SelectiveContextProcess
+
+            context_process = SelectiveContextProcess()
+            cache.init(pre_embedding_func=context_process.pre_process)
     """

     content: str = ""

gptcache/processor/context/summarization_context.py

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,14 @@ class SummarizationContextProcess(ContextProcess):
     :type tokenizer: transformers.PreTrainedTokenizer
     :param target_length: The length of the summarized text.
     :type target_length: int
+
+    Example:
+        .. code-block:: python
+
+            from gptcache.processor.context.summarization_context import SummarizationContextProcess
+
+            context_process = SummarizationContextProcess()
+            cache.init(pre_embedding_func=context_process.pre_process)
     """
     def __init__(self, summarizer=transformers.pipeline("summarization", model="facebook/bart-large-cnn"),
                  tokenizer=None, target_length=512):

gptcache/processor/pre.py

Lines changed: 94 additions & 4 deletions
@@ -1,4 +1,5 @@
 import re
+import string
 from typing import Dict, Any


@@ -47,19 +48,108 @@ def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any])
     return new_content_str


+def _get_pattern_value(pattern_str: str, value_str: str):
+    literal_text_arr = []
+    field_name_arr = []
+    for literal_text, field_name, _, _ in string.Formatter().parse(pattern_str):
+        literal_text_arr.append(literal_text)
+        if field_name is not None:
+            field_name_arr.append(
+                field_name if field_name else str(len(field_name_arr))
+            )
+
+    pattern_values = {}
+    last_end = 0
+    for i, literal_text in enumerate(literal_text_arr):
+        start = value_str.find(literal_text, last_end)
+        if i == len(literal_text_arr) - 1:
+            end = len(value_str)
+        else:
+            end = value_str.find(literal_text_arr[i + 1], start + 1)
+        if start == -1 or end == -1:
+            break
+        start += len(literal_text)
+        pattern_values[field_name_arr[i]] = value_str[start:end]
+        last_end = end
+    return pattern_values
+
+
+def last_content_without_template(data: Dict[str, Any], **params: Dict[str, Any]) -> Any:
+    """get the last content's template values of the message list without template content.
+
+    When considering a cache agent or chain, the majority of the content consists of template content,
+    while the essential information is simply a list of parameters within the template.
+    In this way, the cache key is composed of a string made up of all the parameter values in the list.
+
+    WARNING: Two parameters without intervals cannot appear in the template,
+    for example: template = "{foo}{hoo}" is not supported,
+    but template = "{foo}:{hoo}" is supported
+
+    :param data: the user llm request data
+    :type data: Dict[str, Any]
+
+    :Example with str template:
+        .. code-block:: python
+
+            from gptcache import Config
+            from gptcache.processor.pre import last_content_without_template
+
+            template_obj = "tell me a joke about {subject}"
+            prompt = template_obj.format(subject="animal")
+            value = last_content_without_template(
+                data={"messages": [{"content": prompt}]}, cache_config=Config(template=template_obj)
+            )
+            print(value)
+            # ['animal']
+
+    :Example with langchain template:
+        .. code-block:: python
+
+            from langchain import PromptTemplate
+
+            from gptcache import Config
+            from gptcache.processor.pre import last_content_without_template
+
+            template_obj = PromptTemplate.from_template("tell me a joke about {subject}")
+            prompt = template_obj.format(subject="animal")
+
+            value = last_content_without_template(
+                data={"messages": [{"content": prompt}]},
+                cache_config=Config(template=template_obj.template),
+            )
+            print(value)
+            # ['animal']
+
+    NOTE: At present, only the simple PromptTemplate in langchain is supported.
+    For ChatPromptTemplate, it needs to be adjusted according to the template array.
+    If you need to use it, you need to pass in the final dialog template yourself.
+    The reason why it cannot be advanced is that ChatPromptTemplate
+    does not provide a method to directly return the template string.
+    """
+    last_content_str = data.get("messages")[-1]["content"]
+    cache_config = params.get("cache_config", None)
+    if not (cache_config and cache_config.template):
+        return last_content_str
+
+    pattern_value = _get_pattern_value(cache_config.template, last_content_str)
+    return str(list(pattern_value.values()))
+
+
 def all_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
-    """ get all content of the message list
+    """get all content of the message list

     :param data: the user llm request data
     :type data: Dict[str, Any]

-    Example:
+    :Example:
         .. code-block:: python

             from gptcache.processor.pre import all_content

-            content = all_content({"messages": [{"content": "foo1"}, {"content": "foo2"}]})
-            # content = "foo1\nfoo2"
+            content = all_content(
+                {"messages": [{"content": "foo1"}, {"content": "foo2"}]}
+            )
+            # content = "foo1\\nfoo2"
     """
     s = ""
     messages = data.get("messages")
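To see why the WARNING above rules out adjacent placeholders: _get_pattern_value walks the template's literal segments via string.Formatter().parse() and slices the request string between consecutive literals, so two fields with no literal text between them cannot be told apart. An illustrative trace (not part of the repository, only a sketch of the helper's behavior):

    from gptcache.processor.pre import _get_pattern_value

    _get_pattern_value("tell me a joke about {subject}", "tell me a joke about animal")
    # -> {'subject': 'animal'}

    _get_pattern_value("{foo}:{hoo}", "a:b")
    # -> {'foo': 'a', 'hoo': 'b'}

    # "{foo}{hoo}" has no literal separator between the fields, so the values cannot be split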

tests/unit_tests/adapter/test_langchain_models.py

Lines changed: 52 additions & 3 deletions
@@ -1,15 +1,19 @@
 import os
+import random
 from unittest.mock import patch

-from gptcache import Cache
+from gptcache import Cache, Config
+from gptcache.adapter import openai
+from gptcache.adapter.api import init_similar_cache, get
 from gptcache.adapter.langchain_models import LangChainLLMs, LangChainChat, _cache_msg_data_convert
-from gptcache.processor.pre import get_prompt
+from gptcache.processor.pre import get_prompt, last_content_without_template
 from gptcache.utils import import_pydantic, import_langchain
+from gptcache.utils.response import get_message_from_openai_answer

 import_pydantic()
 import_langchain()

-from langchain import OpenAI
+from langchain import OpenAI, PromptTemplate
 from langchain.chat_models import ChatOpenAI
 from langchain.schema import HumanMessage

@@ -102,3 +106,48 @@ def test_langchain_chats():

     answer = chat(messages=question, cache_obj=llm_cache)
     assert answer == _cache_msg_data_convert(msg).generations[0].message
+
+
+def test_last_content_without_template():
+    string_prompt = PromptTemplate.from_template("tell me a joke about {subject}")
+    template = string_prompt.template
+    cache_obj = Cache()
+    data_dir = str(random.random())
+    init_similar_cache(data_dir=data_dir, cache_obj=cache_obj, pre_func=last_content_without_template, config=Config(template=template))
+
+    subject_str = "animal"
+    expect_answer = "this is a joke"
+
+    with patch("openai.ChatCompletion.create") as mock_create:
+        datas = {
+            "choices": [
+                {
+                    "message": {"content": expect_answer, "role": "assistant"},
+                    "finish_reason": "stop",
+                    "index": 0,
+                }
+            ],
+            "created": 1677825464,
+            "id": "chatcmpl-6ptKyqKOGXZT6iQnqiXAH8adNLUzD",
+            "model": "gpt-3.5-turbo-0301",
+            "object": "chat.completion.chunk",
+        }
+        mock_create.return_value = datas
+
+        response = openai.ChatCompletion.create(
+            model="gpt-3.5-turbo",
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": string_prompt.format(subject=subject_str)},
+            ],
+            cache_obj=cache_obj,
+        )
+        assert get_message_from_openai_answer(response) == expect_answer, response
+
+    cache_obj.flush()
+
+    init_similar_cache(data_dir=data_dir, cache_obj=cache_obj)
+
+    cache_res = get(str([subject_str]), cache_obj=cache_obj)
+    print(str([subject_str]))
+    assert cache_res == expect_answer, cache_res
