Skip to content

Commit e8184c9

Browse files
authored
Fix DispatchTrie validation and reporting of invalid routes (#70)
This regressed in #52, resulting in both false positives (where a `GET` and a `POST` shared the same route, giving an unnecessary error) and false negatives (where multiple `GET`s sharing the same route failed to create an error). The basic problem was that since combining the various HTTP methods into a single routing trie, the old logic comparing uniqueness/duplication/etc. was no longer correct in the new combined trie. This PR fixes it by doing a `groupBy` to split up the entries in the combined trie by HTTP method, before running essentially the same validation. We augment the test suite, tightening up cask/test/src/test/cask/DispatchTrieTests.scala to make it stricter, checking exact error messages to ensure we get not just any failure but the *correct* failure when the validation code triggers. This should hopefully catch this sort of regression in future.
1 parent d6ef66a commit e8184c9

File tree

5 files changed

+148
-27
lines changed

5 files changed

+148
-27
lines changed

cask/src/cask/internal/DispatchTrie.scala

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package cask.internal
22
import collection.mutable
33
object DispatchTrie{
4-
def construct[T](index: Int,
5-
inputs: collection.Seq[(collection.IndexedSeq[String], T, Boolean)]): DispatchTrie[T] = {
4+
def construct[T, V](index: Int,
5+
inputs: collection.Seq[(collection.IndexedSeq[String], T, Boolean)])
6+
(validationGroups: T => Seq[V]): DispatchTrie[T] = {
67
val continuations = mutable.Map.empty[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean)]]
78

89
val terminals = mutable.Buffer.empty[(collection.IndexedSeq[String], T, Boolean)]
@@ -17,29 +18,67 @@ object DispatchTrie{
1718
}
1819
}
1920

21+
for(group <- inputs.flatMap(t => validationGroups(t._2)).distinct) {
22+
val groupTerminals = terminals.flatMap{case (path, v, allowSubpath) =>
23+
validationGroups(v)
24+
.filter(_ == group)
25+
.map{group => (path, v, allowSubpath, group)}
26+
}
27+
28+
val groupContinuations = continuations
29+
.map { case (k, vs) =>
30+
k -> vs.flatMap { case (path, v, allowSubpath) =>
31+
validationGroups(v)
32+
.filter(_ == group)
33+
.map { group => (path, v, allowSubpath, group) }
34+
}
35+
}
36+
.filter(_._2.nonEmpty)
37+
38+
validateGroup(groupTerminals, groupContinuations)
39+
}
40+
41+
DispatchTrie[T](
42+
current = terminals.headOption.map(x => x._2 -> x._3),
43+
children = continuations
44+
.map{ case (k, vs) => (k, construct(index + 1, vs)(validationGroups))}
45+
.toMap
46+
)
47+
}
48+
49+
def validateGroup[T, V](terminals: collection.Seq[(collection.Seq[String], T, Boolean, V)],
50+
continuations: mutable.Map[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean, V)]]) = {
2051
val wildcards = continuations.filter(_._1(0) == ':')
21-
if (terminals.length > 1){
52+
53+
def renderTerminals = terminals
54+
.map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"}
55+
.mkString(", ")
56+
57+
def renderContinuations = continuations.toSeq
58+
.flatMap(_._2)
59+
.map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"}
60+
.mkString(", ")
61+
62+
if (terminals.length > 1) {
2263
throw new Exception(
23-
"More than one endpoint has the same path: " +
24-
terminals.map(_._1.map(_.mkString("/"))).mkString(", ")
64+
s"More than one endpoint has the same path: $renderTerminals"
2565
)
26-
} else if(wildcards.size >= 1 && continuations.size > 1) {
66+
}
67+
68+
if (wildcards.size >= 1 && continuations.size > 1) {
2769
throw new Exception(
28-
"Routes overlap with wildcards: " +
29-
(wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/"))
70+
s"Routes overlap with wildcards: $renderContinuations"
3071
)
31-
}else if (terminals.headOption.exists(_._3) && continuations.size == 1){
72+
}
73+
74+
if (terminals.headOption.exists(_._3) && continuations.size == 1) {
3275
throw new Exception(
33-
"Routes overlap with subpath capture: " +
34-
(wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/"))
35-
)
36-
}else{
37-
DispatchTrie[T](
38-
current = terminals.headOption.map(x => x._2 -> x._3),
39-
children = continuations.map{ case (k, vs) => (k, construct(index + 1, vs))}.toMap
76+
s"Routes overlap with subpath capture: $renderTerminals, $renderContinuations"
4077
)
4178
}
4279
}
80+
81+
def renderPath(p: collection.Seq[String]) = " /" + p.mkString("/")
4382
}
4483

4584
/**
@@ -72,4 +111,9 @@ case class DispatchTrie[T](current: Option[(T, Boolean)],
72111

73112
}
74113
}
114+
115+
def map[V](f: T => V): DispatchTrie[V] = DispatchTrie(
116+
current.map{case (t, v) => (f(t), v)},
117+
children.map { case (k, v) => (k, v.map(f))}
118+
)
75119
}

cask/src/cask/main/Main.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,12 @@ object Main{
154154
}
155155

156156
val dispatchInputs = flattenedRoutes.groupBy(_._1).map { case (segments, values) =>
157-
val methodMap = values.map(_._2).flatten.toMap
157+
val methodMap = values.map(_._2).flatten
158158
val hasSubpath = values.map(_._3).contains(true)
159159
(segments, methodMap, hasSubpath)
160160
}.toSeq
161161

162-
DispatchTrie.construct(0, dispatchInputs)
162+
DispatchTrie.construct(0, dispatchInputs)(_.map(_._1)).map(_.toMap)
163163
}
164164

165165
def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = {

cask/test/src/test/cask/DispatchTrieTests.scala

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ object DispatchTrieTests extends TestSuite {
88
"hello" - {
99
val x = DispatchTrie.construct(0,
1010
Seq((Vector("hello"), 1, false))
11-
)
11+
)(Seq(_))
1212

1313
assert(
1414
x.lookup(List("hello"), Map()) == Some((1, Map(), Nil)),
@@ -22,7 +22,7 @@ object DispatchTrieTests extends TestSuite {
2222
(Vector("hello", "world"), 1, false),
2323
(Vector("hello", "cow"), 2, false)
2424
)
25-
)
25+
)(Seq(_))
2626
assert(
2727
x.lookup(List("hello", "world"), Map()) == Some((1, Map(), Nil)),
2828
x.lookup(List("hello", "cow"), Map()) == Some((2, Map(), Nil)),
@@ -34,7 +34,7 @@ object DispatchTrieTests extends TestSuite {
3434
"bindings" - {
3535
val x = DispatchTrie.construct(0,
3636
Seq((Vector(":hello", ":world"), 1, false))
37-
)
37+
)(Seq(_))
3838
assert(
3939
x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)),
4040
x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)),
@@ -47,7 +47,7 @@ object DispatchTrieTests extends TestSuite {
4747
"path" - {
4848
val x = DispatchTrie.construct(0,
4949
Seq((Vector("hello"), 1, true))
50-
)
50+
)(Seq(_))
5151

5252
assert(
5353
x.lookup(List("hello", "world"), Map()) == Some((1,Map(), Seq("world"))),
@@ -58,44 +58,113 @@ object DispatchTrieTests extends TestSuite {
5858
}
5959

6060
"errors" - {
61-
intercept[Exception]{
61+
test - {
6262
DispatchTrie.construct(0,
6363
Seq(
6464
(Vector("hello", ":world"), 1, false),
6565
(Vector("hello", "world"), 2, false)
6666
)
67+
)(Seq(_))
68+
69+
val ex = intercept[Exception]{
70+
DispatchTrie.construct(0,
71+
Seq(
72+
(Vector("hello", ":world"), 1, false),
73+
(Vector("hello", "world"), 1, false)
74+
)
75+
)(Seq(_))
76+
}
77+
78+
assert(
79+
ex.getMessage ==
80+
"Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world"
6781
)
6882
}
69-
intercept[Exception]{
83+
test - {
7084
DispatchTrie.construct(0,
7185
Seq(
7286
(Vector("hello", ":world"), 1, false),
7387
(Vector("hello", "world", "omg"), 2, false)
7488
)
89+
)(Seq(_))
90+
91+
val ex = intercept[Exception]{
92+
DispatchTrie.construct(0,
93+
Seq(
94+
(Vector("hello", ":world"), 1, false),
95+
(Vector("hello", "world", "omg"), 1, false)
96+
)
97+
)(Seq(_))
98+
}
99+
100+
assert(
101+
ex.getMessage ==
102+
"Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world/omg"
75103
)
76104
}
77-
intercept[Exception]{
105+
test - {
78106
DispatchTrie.construct(0,
79107
Seq(
80108
(Vector("hello"), 1, true),
81109
(Vector("hello", "cow", "omg"), 2, false)
82110
)
111+
)(Seq(_))
112+
113+
val ex = intercept[Exception]{
114+
DispatchTrie.construct(0,
115+
Seq(
116+
(Vector("hello"), 1, true),
117+
(Vector("hello", "cow", "omg"), 1, false)
118+
)
119+
)(Seq(_))
120+
}
121+
122+
assert(
123+
ex.getMessage ==
124+
"Routes overlap with subpath capture: 1 /hello, 1 /hello/cow/omg"
83125
)
84126
}
85-
intercept[Exception]{
127+
test - {
86128
DispatchTrie.construct(0,
87129
Seq(
88130
(Vector("hello", ":world"), 1, false),
89131
(Vector("hello", ":cow"), 2, false)
90132
)
133+
)(Seq(_))
134+
135+
val ex = intercept[Exception]{
136+
DispatchTrie.construct(0,
137+
Seq(
138+
(Vector("hello", ":world"), 1, false),
139+
(Vector("hello", ":cow"), 1, false)
140+
)
141+
)(Seq(_))
142+
}
143+
144+
assert(
145+
ex.getMessage ==
146+
"Routes overlap with wildcards: 1 /hello/:world, 1 /hello/:cow"
91147
)
92148
}
93-
intercept[Exception]{
149+
test - {
94150
DispatchTrie.construct(0,
95151
Seq(
96152
(Vector("hello", "world"), 1, false),
97153
(Vector("hello", "world"), 2, false)
98154
)
155+
)(Seq(_))
156+
157+
val ex = intercept[Exception]{
158+
DispatchTrie.construct(0,
159+
Seq(
160+
(Vector("hello", "world"), 1, false),
161+
(Vector("hello", "world"), 1, false)
162+
)
163+
)(Seq(_))
164+
}
165+
assert(
166+
ex.getMessage ==
167+
"More than one endpoint has the same path: 1 /hello/world, 1 /hello/world"
99168
)
100169
}
101170
}

example/variableRoutes/app/src/VariableRoutes.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,10 @@ object VariableRoutes extends cask.MainRoutes{
1515
s"Subpath ${request.remainingPathSegments}"
1616
}
1717

18+
@cask.post("/path", subpath = true)
19+
def postShowSubpath(request: cask.Request) = {
20+
s"POST Subpath ${request.remainingPathSegments}"
21+
}
22+
1823
initialize()
1924
}

example/variableRoutes/app/test/src/ExampleTests.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ object ExampleTests extends TestSuite{
4646

4747
requests.get(s"$host/path/one/two/three").text() ==>
4848
"Subpath List(one, two, three)"
49+
50+
requests.post(s"$host/path/one/two/three").text() ==>
51+
"POST Subpath List(one, two, three)"
4952
}
5053

5154
}

0 commit comments

Comments
 (0)