
Commit 5888034

Merge branch 'main' into xpu_support_pr
2 parents: bd10f79 + ed72e92

40 files changed: +2481 additions, -1094 deletions

Cargo.lock

Lines changed: 107 additions & 141 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 5 additions & 1 deletion
@@ -9,11 +9,15 @@ members = [
 resolver = "2"

 [workspace.package]
-version = "2.0.0"
+version = "2.0.1"
 edition = "2021"
 authors = ["Olivier Dehaene"]
 homepage = "https://github.com/huggingface/text-generation-inference"

+[workspace.dependencies]
+tokenizers = { version = "0.19.1", features = ["http"] }
+hf-hub = { version = "0.3.1", features = ["tokio"] }
+
 [profile.release]
 debug = 1
 incremental = true

benchmark/Cargo.toml

Lines changed: 2 additions & 2 deletions
@@ -23,9 +23,9 @@ serde_json = "1.0"
 tabled = "0.14.0"
 text-generation-client = { path = "../router/client" }
 thiserror = "1.0.48"
-tokenizers = { version = "0.14.0", features = ["http"] }
+tokenizers = { workspace = true }
 tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync", "macros"] }
 tui = {package = "ratatui", version = "0.23", default-features = false, features = ["crossterm"]}
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["json", "env-filter"] }
-hf-hub = "0.3.1"
+hf-hub = { workspace = true }

clients/python/tests/conftest.py

Lines changed: 10 additions & 0 deletions
@@ -9,6 +9,11 @@ def flan_t5_xxl():
     return "google/flan-t5-xxl"


+@pytest.fixture
+def llama_7b():
+    return "meta-llama/Llama-2-7b-chat-hf"
+
+
 @pytest.fixture
 def fake_model():
     return "fake/model"
@@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl):
     return f"{base_url}/{flan_t5_xxl}"


+@pytest.fixture
+def llama_7b_url(base_url, llama_7b):
+    return f"{base_url}/{llama_7b}"
+
+
 @pytest.fixture
 def fake_url(base_url, fake_model):
     return f"{base_url}/{fake_model}"

clients/python/tests/test_client.py

Lines changed: 35 additions & 32 deletions
@@ -5,24 +5,24 @@
 from text_generation.types import FinishReason, InputToken


-def test_generate(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     response = client.generate("test", max_new_tokens=1, decoder_input_details=True)

-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.prefill) == 2
+    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
     assert len(response.details.tokens) == 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
+    assert response.details.tokens[0].id == 29918
+    assert response.details.tokens[0].text == "_"
     assert not response.details.tokens[0].special


-def test_generate_best_of(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_best_of(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     response = client.generate(
         "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
@@ -39,22 +39,22 @@ def test_generate_not_found(fake_url, hf_headers):
         client.generate("test")


-def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_validation_error(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         client.generate("test", max_new_tokens=10_000)


-def test_generate_stream(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_stream(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]

     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
         list(client.generate_stream("test"))


-def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
-    client = Client(flan_t5_xxl_url, hf_headers)
+def test_generate_stream_validation_error(llama_7b_url, hf_headers):
+    client = Client(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         list(client.generate_stream("test", max_new_tokens=10_000))


 @pytest.mark.asyncio
-async def test_generate_async(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     response = await client.generate(
         "test", max_new_tokens=1, decoder_input_details=True
     )

-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
-    assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.prefill) == 2
+    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
+    assert response.details.prefill[1] == InputToken(
+        id=1243, text="test", logprob=-10.96875
+    )
     assert len(response.details.tokens) == 1
-    assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == " "
+    assert response.details.tokens[0].id == 29918
+    assert response.details.tokens[0].text == "_"
     assert not response.details.tokens[0].special


 @pytest.mark.asyncio
-async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async_best_of(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     response = await client.generate(
         "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
@@ -112,23 +115,23 @@ async def test_generate_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_async_validation_error(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         await client.generate("test", max_new_tokens=10_000)


 @pytest.mark.asyncio
-async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_stream_async(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]

     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == ""
+    assert response.generated_text == "_"
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers):
-    client = AsyncClient(flan_t5_xxl_url, hf_headers)
+async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
+    client = AsyncClient(llama_7b_url, hf_headers)
     with pytest.raises(ValidationError):
         async for _ in client.generate_stream("test", max_new_tokens=10_000):
             pass
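
The new expected values (prefill ids 1 and 1243, generated token 29918 decoding to "_") come from the Llama-2 tokenizer that replaces Flan-T5 in these tests. A minimal sketch of a local sanity check using the transformers tokenizer; this is not part of the commit and assumes access to the gated meta-llama checkpoint:

```python
from transformers import AutoTokenizer

# Hypothetical check, not in the diff: reproduce the token ids the updated
# tests expect for the prompt "test".
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

print(tok("test").input_ids)  # per the tests above: [1, 1243] -> "<s>", "test"
print(tok.decode([29918]))    # per the tests above: "_"
```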

clients/python/text_generation/types.py

Lines changed: 21 additions & 0 deletions
@@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
     usage: Optional[Any] = None


+class CompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    text: str
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+
+
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
     usage: Any


+class Completion(BaseModel):
+    # Completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[CompletionComplete]
+
+
 class ChatRequest(BaseModel):
     # Model identifier
     model: str
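
The two new models mirror the OpenAI-style completions payload documented in docs/openapi.json below. A minimal sketch of how they might be used to parse a response body; the field values here are invented for illustration:

```python
from text_generation.types import Completion, CompletionComplete

# Hypothetical payload shaped like the Completion schema added above.
payload = {
    "id": "",
    "object": "text_completion",
    "created": 1712000000,
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "system_fingerprint": "2.0.1-native",
    "choices": [
        {"index": 0, "text": "_", "logprobs": None, "finish_reason": "length"}
    ],
}

completion = Completion(**payload)
assert isinstance(completion.choices[0], CompletionComplete)
print(completion.choices[0].text)
```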

docs/openapi.json

Lines changed: 30 additions & 10 deletions
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "2.0.0"
+    "version": "2.0.1"
   },
   "paths": {
     "/": {
@@ -408,9 +408,14 @@
         },
         "responses": {
           "200": {
-            "description": "Generated Text",
+            "description": "Generated Chat Completion",
             "content": {
               "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ChatCompletion"
+                }
+              },
+              "text/event-stream": {
                 "schema": {
                   "$ref": "#/components/schemas/ChatCompletionChunk"
                 }
@@ -492,11 +497,16 @@
         },
         "responses": {
           "200": {
-            "description": "Generated Text",
+            "description": "Generated Chat Completion",
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/ChatCompletionChunk"
+                  "$ref": "#/components/schemas/Completion"
+                }
+              },
+              "text/event-stream": {
+                "schema": {
+                  "$ref": "#/components/schemas/CompletionCompleteChunk"
                 }
               }
             }
@@ -930,7 +940,7 @@
         "tool_prompt": {
           "type": "string",
           "description": "A prompt to be appended before the tools",
-          "example": "\"Based on the conversation, please choose the most appropriate tool to use: \"",
+          "example": "\"You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks of sufficient information to make a precise tool selection: Do not invent any tool's properties, instead notify with an error message.\n\nJSON Schema:\n\"",
           "nullable": true
         },
         "tools": {
@@ -1071,7 +1081,10 @@
           "example": "mistralai/Mistral-7B-Instruct-v0.2"
         },
         "prompt": {
-          "type": "string",
+          "type": "array",
+          "items": {
+            "type": "string"
+          },
           "description": "The prompt to generate completions for.",
           "example": "What is Deep Learning?"
         },
@@ -1234,17 +1247,17 @@
       "type": "object",
       "required": [
         "name",
-        "parameters"
+        "arguments"
       ],
       "properties": {
+        "arguments": {},
         "description": {
           "type": "string",
           "nullable": true
         },
         "name": {
           "type": "string"
-        },
-        "parameters": {}
+        }
       }
     },
     "GenerateParameters": {
@@ -1260,7 +1273,7 @@
         },
         "decoder_input_details": {
           "type": "boolean",
-          "default": "true"
+          "default": "false"
         },
         "details": {
           "type": "boolean",
@@ -1285,6 +1298,7 @@
             "$ref": "#/components/schemas/GrammarType"
           }
         ],
+        "default": "null",
         "nullable": true
       },
       "max_new_tokens": {
@@ -1478,6 +1492,7 @@
         "max_batch_total_tokens",
         "max_waiting_tokens",
         "validation_workers",
+        "max_client_batch_size",
         "version"
       ],
       "properties": {
@@ -1503,6 +1518,11 @@
           "example": "2",
           "minimum": 0
         },
+        "max_client_batch_size": {
+          "type": "integer",
+          "example": "32",
+          "minimum": 0
+        },
         "max_concurrent_requests": {
           "type": "integer",
           "description": "Router Parameters",

docs/source/basic_tutorials/launcher.md

Lines changed: 9 additions & 0 deletions
@@ -398,6 +398,15 @@ Options:
   -e, --env
           Display a lot of information about your runtime environment

+```
+## MAX_CLIENT_BATCH_SIZE
+```shell
+      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
+          Control the maximum number of inputs that a client can send in a single request
+
+          [env: MAX_CLIENT_BATCH_SIZE=]
+          [default: 4]
+
 ```
 ## HELP
 ```shell
