5
5
from text_generation .types import FinishReason , InputToken
6
6
7
7
8
- def test_generate (flan_t5_xxl_url , hf_headers ):
9
- client = Client (flan_t5_xxl_url , hf_headers )
8
+ def test_generate (llama_7b_url , hf_headers ):
9
+ client = Client (llama_7b_url , hf_headers )
10
10
response = client .generate ("test" , max_new_tokens = 1 , decoder_input_details = True )
11
11
12
- assert response .generated_text == ""
12
+ assert response .generated_text == "_ "
13
13
assert response .details .finish_reason == FinishReason .Length
14
14
assert response .details .generated_tokens == 1
15
15
assert response .details .seed is None
16
- assert len (response .details .prefill ) == 1
17
- assert response .details .prefill [0 ] == InputToken (id = 0 , text = "<pad >" , logprob = None )
16
+ assert len (response .details .prefill ) == 2
17
+ assert response .details .prefill [0 ] == InputToken (id = 1 , text = "<s >" , logprob = None )
18
18
assert len (response .details .tokens ) == 1
19
- assert response .details .tokens [0 ].id == 3
20
- assert response .details .tokens [0 ].text == " "
19
+ assert response .details .tokens [0 ].id == 29918
20
+ assert response .details .tokens [0 ].text == "_ "
21
21
assert not response .details .tokens [0 ].special
22
22
23
23
24
- def test_generate_best_of (flan_t5_xxl_url , hf_headers ):
25
- client = Client (flan_t5_xxl_url , hf_headers )
24
+ def test_generate_best_of (llama_7b_url , hf_headers ):
25
+ client = Client (llama_7b_url , hf_headers )
26
26
response = client .generate (
27
27
"test" , max_new_tokens = 1 , best_of = 2 , do_sample = True , decoder_input_details = True
28
28
)
@@ -39,22 +39,22 @@ def test_generate_not_found(fake_url, hf_headers):
39
39
client .generate ("test" )
40
40
41
41
42
- def test_generate_validation_error (flan_t5_xxl_url , hf_headers ):
43
- client = Client (flan_t5_xxl_url , hf_headers )
42
+ def test_generate_validation_error (llama_7b_url , hf_headers ):
43
+ client = Client (llama_7b_url , hf_headers )
44
44
with pytest .raises (ValidationError ):
45
45
client .generate ("test" , max_new_tokens = 10_000 )
46
46
47
47
48
- def test_generate_stream (flan_t5_xxl_url , hf_headers ):
49
- client = Client (flan_t5_xxl_url , hf_headers )
48
+ def test_generate_stream (llama_7b_url , hf_headers ):
49
+ client = Client (llama_7b_url , hf_headers )
50
50
responses = [
51
51
response for response in client .generate_stream ("test" , max_new_tokens = 1 )
52
52
]
53
53
54
54
assert len (responses ) == 1
55
55
response = responses [0 ]
56
56
57
- assert response .generated_text == ""
57
+ assert response .generated_text == "_ "
58
58
assert response .details .finish_reason == FinishReason .Length
59
59
assert response .details .generated_tokens == 1
60
60
assert response .details .seed is None
@@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
66
66
list (client .generate_stream ("test" ))
67
67
68
68
69
- def test_generate_stream_validation_error (flan_t5_xxl_url , hf_headers ):
70
- client = Client (flan_t5_xxl_url , hf_headers )
69
+ def test_generate_stream_validation_error (llama_7b_url , hf_headers ):
70
+ client = Client (llama_7b_url , hf_headers )
71
71
with pytest .raises (ValidationError ):
72
72
list (client .generate_stream ("test" , max_new_tokens = 10_000 ))
73
73
74
74
75
75
@pytest.mark.asyncio
async def test_generate_async(llama_7b_url, hf_headers):
    """Async single-token generation mirrors the sync test.

    Additionally pins the prompt token's prefill entry, including its
    logprob (-10.96875 for token id 1243, "test").
    """
    client = AsyncClient(llama_7b_url, hf_headers)
    response = await client.generate(
        "test", max_new_tokens=1, decoder_input_details=True
    )

    assert response.generated_text == "_"
    # Generation stopped because max_new_tokens was reached.
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    # Greedy decoding: no sampling seed is reported.
    assert response.details.seed is None
    # Prefill is BOS (<s>, id 1) plus the prompt token with its logprob.
    assert len(response.details.prefill) == 2
    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
    assert response.details.prefill[1] == InputToken(
        id=1243, text="test", logprob=-10.96875
    )
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0].id == 29918
    assert response.details.tokens[0].text == "_"
    assert not response.details.tokens[0].special
92
95
93
96
94
97
@pytest .mark .asyncio
95
- async def test_generate_async_best_of (flan_t5_xxl_url , hf_headers ):
96
- client = AsyncClient (flan_t5_xxl_url , hf_headers )
98
+ async def test_generate_async_best_of (llama_7b_url , hf_headers ):
99
+ client = AsyncClient (llama_7b_url , hf_headers )
97
100
response = await client .generate (
98
101
"test" , max_new_tokens = 1 , best_of = 2 , do_sample = True , decoder_input_details = True
99
102
)
@@ -112,23 +115,23 @@ async def test_generate_async_not_found(fake_url, hf_headers):
112
115
113
116
114
117
@pytest.mark.asyncio
async def test_generate_async_validation_error(llama_7b_url, hf_headers):
    """Async request with an oversized max_new_tokens raises ValidationError."""
    client = AsyncClient(llama_7b_url, hf_headers)
    with pytest.raises(ValidationError):
        await client.generate("test", max_new_tokens=10_000)
119
122
120
123
121
124
@pytest .mark .asyncio
122
- async def test_generate_stream_async (flan_t5_xxl_url , hf_headers ):
123
- client = AsyncClient (flan_t5_xxl_url , hf_headers )
125
+ async def test_generate_stream_async (llama_7b_url , hf_headers ):
126
+ client = AsyncClient (llama_7b_url , hf_headers )
124
127
responses = [
125
128
response async for response in client .generate_stream ("test" , max_new_tokens = 1 )
126
129
]
127
130
128
131
assert len (responses ) == 1
129
132
response = responses [0 ]
130
133
131
- assert response .generated_text == ""
134
+ assert response .generated_text == "_ "
132
135
assert response .details .finish_reason == FinishReason .Length
133
136
assert response .details .generated_tokens == 1
134
137
assert response .details .seed is None
@@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
143
146
144
147
145
148
@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
    """Async streaming with an oversized max_new_tokens raises ValidationError.

    The async-for drains the stream so the validation error surfaces.
    """
    client = AsyncClient(llama_7b_url, hf_headers)
    with pytest.raises(ValidationError):
        async for _ in client.generate_stream("test", max_new_tokens=10_000):
            pass
0 commit comments