Skip to content

Commit bc4c920

Browse files
efectnReneWerner87
andauthored
bind: add support for multipart file binding (#3309)
* deps: update schema to v1.3.0 * bind: add support for multipart file binding * bind: fix linter * improve coverage * fix linter * add test cases --------- Co-authored-by: René <[email protected]>
1 parent d6d48d8 commit bc4c920

File tree

6 files changed

+294
-13
lines changed

6 files changed

+294
-13
lines changed

binder/form.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package binder
22

33
import (
4+
"mime/multipart"
5+
46
"github.com/gofiber/utils/v2"
57
"github.com/valyala/fasthttp"
68
)
@@ -59,7 +61,15 @@ func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
5961
}
6062
}
6163

62-
return parse(b.Name(), out, data)
64+
files := make(map[string][]*multipart.FileHeader)
65+
for key, values := range multipartForm.File {
66+
err = formatBindData(out, files, key, values, b.EnableSplitting, true)
67+
if err != nil {
68+
return err
69+
}
70+
}
71+
72+
return parse(b.Name(), out, data, files)
6373
}
6474

6575
// Reset resets the FormBinding binder.

binder/form_test.go

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package binder
22

33
import (
44
"bytes"
5+
"io"
56
"mime/multipart"
67
"testing"
78

@@ -98,10 +99,12 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
9899
}
99100

100101
type User struct {
101-
Name string `form:"name"`
102-
Names []string `form:"names"`
103-
Posts []Post `form:"posts"`
104-
Age int `form:"age"`
102+
Avatar *multipart.FileHeader `form:"avatar"`
103+
Name string `form:"name"`
104+
Names []string `form:"names"`
105+
Posts []Post `form:"posts"`
106+
Avatars []*multipart.FileHeader `form:"avatars"`
107+
Age int `form:"age"`
105108
}
106109
var user User
107110

@@ -118,6 +121,24 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
118121
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
119122
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))
120123

124+
writer, err := mw.CreateFormFile("avatar", "avatar.txt")
125+
require.NoError(t, err)
126+
127+
_, err = writer.Write([]byte("avatar"))
128+
require.NoError(t, err)
129+
130+
writer, err = mw.CreateFormFile("avatars", "avatar1.txt")
131+
require.NoError(t, err)
132+
133+
_, err = writer.Write([]byte("avatar1"))
134+
require.NoError(t, err)
135+
136+
writer, err = mw.CreateFormFile("avatars", "avatar2.txt")
137+
require.NoError(t, err)
138+
139+
_, err = writer.Write([]byte("avatar2"))
140+
require.NoError(t, err)
141+
121142
require.NoError(t, mw.Close())
122143

123144
req.Header.SetContentType(mw.FormDataContentType())
@@ -127,7 +148,7 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
127148
fasthttp.ReleaseRequest(req)
128149
})
129150

130-
err := b.Bind(req, &user)
151+
err = b.Bind(req, &user)
131152

132153
require.NoError(t, err)
133154
require.Equal(t, "john", user.Name)
@@ -139,6 +160,38 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
139160
require.Equal(t, "post1", user.Posts[0].Title)
140161
require.Equal(t, "post2", user.Posts[1].Title)
141162
require.Equal(t, "post3", user.Posts[2].Title)
163+
164+
require.NotNil(t, user.Avatar)
165+
require.Equal(t, "avatar.txt", user.Avatar.Filename)
166+
require.Equal(t, "application/octet-stream", user.Avatar.Header.Get("Content-Type"))
167+
168+
file, err := user.Avatar.Open()
169+
require.NoError(t, err)
170+
171+
content, err := io.ReadAll(file)
172+
require.NoError(t, err)
173+
require.Equal(t, "avatar", string(content))
174+
175+
require.Len(t, user.Avatars, 2)
176+
require.Equal(t, "avatar1.txt", user.Avatars[0].Filename)
177+
require.Equal(t, "application/octet-stream", user.Avatars[0].Header.Get("Content-Type"))
178+
179+
file, err = user.Avatars[0].Open()
180+
require.NoError(t, err)
181+
182+
content, err = io.ReadAll(file)
183+
require.NoError(t, err)
184+
require.Equal(t, "avatar1", string(content))
185+
186+
require.Equal(t, "avatar2.txt", user.Avatars[1].Filename)
187+
require.Equal(t, "application/octet-stream", user.Avatars[1].Header.Get("Content-Type"))
188+
189+
file, err = user.Avatars[1].Open()
190+
require.NoError(t, err)
191+
192+
content, err = io.ReadAll(file)
193+
require.NoError(t, err)
194+
require.Equal(t, "avatar2", string(content))
142195
}
143196

144197
func Benchmark_FormBinder_BindMultipart(b *testing.B) {

binder/mapping.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package binder
33
import (
44
"errors"
55
"fmt"
6+
"mime/multipart"
67
"reflect"
78
"strings"
89
"sync"
@@ -69,7 +70,7 @@ func init() {
6970
}
7071

7172
// parse data into the map or struct
72-
func parse(aliasTag string, out any, data map[string][]string) error {
73+
func parse(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
7374
ptrVal := reflect.ValueOf(out)
7475

7576
// Get pointer value
@@ -83,19 +84,19 @@ func parse(aliasTag string, out any, data map[string][]string) error {
8384
}
8485

8586
// Parse into the struct
86-
return parseToStruct(aliasTag, out, data)
87+
return parseToStruct(aliasTag, out, data, files...)
8788
}
8889

8990
// Parse data into the struct with gorilla/schema
90-
func parseToStruct(aliasTag string, out any, data map[string][]string) error {
91+
func parseToStruct(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
9192
// Get decoder from pool
9293
schemaDecoder := decoderPoolMap[aliasTag].Get().(*schema.Decoder) //nolint:errcheck,forcetypeassert // not needed
9394
defer decoderPoolMap[aliasTag].Put(schemaDecoder)
9495

9596
// Set alias tag
9697
schemaDecoder.SetAliasTag(aliasTag)
9798

98-
if err := schemaDecoder.Decode(out, data); err != nil {
99+
if err := schemaDecoder.Decode(out, data, files...); err != nil {
99100
return fmt.Errorf("bind: %w", err)
100101
}
101102

@@ -250,7 +251,7 @@ func FilterFlags(content string) string {
250251
return content
251252
}
252253

253-
func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
254+
func formatBindData[T, K any](out any, data map[string][]T, key string, value K, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
254255
var err error
255256
if supportBracketNotation && strings.Contains(key, "[") {
256257
key, err = parseParamSquareBrackets(key)
@@ -261,10 +262,28 @@ func formatBindData[T any](out any, data map[string][]string, key string, value
261262

262263
switch v := any(value).(type) {
263264
case string:
264-
assignBindData(out, data, key, v, enableSplitting)
265+
dataMap, ok := any(data).(map[string][]string)
266+
if !ok {
267+
return fmt.Errorf("unsupported value type: %T", value)
268+
}
269+
270+
assignBindData(out, dataMap, key, v, enableSplitting)
265271
case []string:
272+
dataMap, ok := any(data).(map[string][]string)
273+
if !ok {
274+
return fmt.Errorf("unsupported value type: %T", value)
275+
}
276+
277+
for _, val := range v {
278+
assignBindData(out, dataMap, key, val, enableSplitting)
279+
}
280+
case []*multipart.FileHeader:
266281
for _, val := range v {
267-
assignBindData(out, data, key, val, enableSplitting)
282+
valT, ok := any(val).(T)
283+
if !ok {
284+
return fmt.Errorf("unsupported value type: %T", value)
285+
}
286+
data[key] = append(data[key], valT)
268287
}
269288
default:
270289
return fmt.Errorf("unsupported value type: %T", value)

0 commit comments

Comments
 (0)