Skip to content

Commit cd1fdae

Browse files
authored
Fix case class default parameter detection and annotation detection in Scala3 (#88)
Fixes #33, as well as `@arg` annotations not working on `case class` arguments (which doesn't have an open issue) Just need some special casing for `apply`/`<init>`, both for getting the `apply` defaults (which are named `<init>$default$n` since they're shared) and for getting the `apply` parameter annotations (which end up on the `<init>` method parameters, rather than the `apply` method) Re-enabled a bunch of previously-disabled tests, and added a new suite to exercise default parameters in direct/nested scenarios
1 parent bd7b871 commit cd1fdae

File tree

6 files changed

+190
-111
lines changed

6 files changed

+190
-111
lines changed

build.sc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ trait MainArgsPublishModule
5555
)
5656

5757
def ivyDeps = Agg(
58-
ivy"org.scala-lang.modules::scala-collection-compat::2.8.1"
58+
ivy"org.scala-lang.modules::scala-collection-compat::2.8.1",
59+
ivy"com.lihaoyi::pprint:0.8.1"
5960
)
6061
}
6162

mainargs/src-3/Macros.scala

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,37 @@ object Macros {
4040
val annotatedMethod = TypeRepr.of[B].typeSymbol.companionModule.memberMethod("apply").head
4141
companionModuleType match
4242
case '[bCompanion] =>
43-
val mainData = createMainData[B, Any](annotatedMethod, mainAnnotationInstance)
43+
val mainData = createMainData[B, Any](
44+
annotatedMethod,
45+
mainAnnotationInstance,
46+
// Somehow the `apply` method parameter annotations don't end up on
47+
// the `apply` method parameters, but end up in the `<init>` method
48+
// parameters, so use those for getting the annotations instead
49+
TypeRepr.of[B].typeSymbol.primaryConstructor.paramSymss
50+
)
4451
'{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) }
4552
}
4653

47-
def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
54+
def createMainData[T: Type, B: Type](using Quotes)
55+
(method: quotes.reflect.Symbol,
56+
mainAnnotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
57+
createMainData[T, B](method, mainAnnotation, method.paramSymss)
58+
}
59+
60+
def createMainData[T: Type, B: Type](using Quotes)
61+
(method: quotes.reflect.Symbol,
62+
mainAnnotation: quotes.reflect.Term,
63+
annotatedParamsLists: List[List[quotes.reflect.Symbol]]): Expr[MainData[T, B]] = {
64+
4865
import quotes.reflect.*
4966
val params = method.paramSymss.headOption.getOrElse(report.throwError("Multiple parameter lists not supported"))
5067
val defaultParams = getDefaultParams(method)
51-
val argSigs = Expr.ofList(params.map { param =>
68+
val argSigsExprs = params.zip(annotatedParamsLists.flatten).map { paramAndAnnotParam =>
69+
val param = paramAndAnnotParam._1
70+
val annotParam = paramAndAnnotParam._2
5271
val paramTree = param.tree.asInstanceOf[ValDef]
5372
val paramTpe = paramTree.tpt.tpe
54-
val arg = param.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
73+
val arg = annotParam.getAnnotation(argAnnotation).map(_.asExprOf[mainargs.arg]).getOrElse('{ new mainargs.arg() })
5574
val paramType = paramTpe.asType
5675
paramType match
5776
case '[t] =>
@@ -66,13 +85,14 @@ object Macros {
6685
)
6786
}
6887
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) }
69-
})
88+
}
89+
val argSigs = Expr.ofList(argSigsExprs)
7090

7191
val invokeRaw: Expr[(B, Seq[Any]) => T] = {
7292
def callOf(args: Expr[Seq[Any]]) = call(method, '{ Seq( ${ args }) }).asExprOf[T]
7393
'{ ((b: B, params: Seq[Any]) => ${ callOf('{ params }) }) }
7494
}
75-
'{ MainData.create[T, B](${ Expr(method.name) }, ${ annotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
95+
'{ MainData.create[T, B](${ Expr(method.name) }, ${ mainAnnotation.asExprOf[mainargs.main] }, ${ argSigs }, ${ invokeRaw }) }
7696
}
7797

7898
/** Call a method given by its symbol.
@@ -134,12 +154,21 @@ object Macros {
134154
val defaults = collection.mutable.Map.empty[Symbol, Expr[Any]]
135155

136156
val Name = (method.name + """\$default\$(\d+)""").r
157+
val InitName = """\$lessinit\$greater\$default\$(\d+)""".r
137158

138159
val idents = method.owner.tree.asInstanceOf[ClassDef].body
160+
139161
idents.foreach{
140162
case deff @ DefDef(Name(idx), _, _, _) =>
141163
val expr = Ref(deff.symbol).asExpr
142164
defaults += (params(idx.toInt - 1) -> expr)
165+
166+
// The `apply` method re-uses the default param factory methods from `<init>`,
167+
// so make sure to check if those exist too
168+
case deff @ DefDef(InitName(idx), _, _, _) if method.name == "apply" =>
169+
val expr = Ref(deff.symbol).asExpr
170+
defaults += (params(idx.toInt - 1) -> expr)
171+
143172
case _ =>
144173
}
145174

mainargs/test/src/ClassTests.scala

Lines changed: 86 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -58,121 +58,108 @@ object ClassTests extends TestSuite {
5858
Bar(Flag(true), Foo(1, 2), "xxx")
5959
}
6060
test("missingInner") {
61-
// Blocked by https://github.com/lampepfl/dotty/issues/12492
62-
TestUtils.scala2Only {
63-
barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==>
64-
Result.Failure.MismatchedArguments(
65-
Seq(
66-
ArgSig(
67-
None,
68-
Some('y'),
69-
None,
70-
None,
71-
mainargs.TokensReader.IntRead,
72-
positional = false,
73-
hidden = false
74-
)
75-
),
76-
List(),
77-
List(),
78-
None
79-
)
80-
}
61+
barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==>
62+
Result.Failure.MismatchedArguments(
63+
Seq(
64+
ArgSig(
65+
None,
66+
Some('y'),
67+
None,
68+
None,
69+
mainargs.TokensReader.IntRead,
70+
positional = false,
71+
hidden = false
72+
)
73+
),
74+
List(),
75+
List(),
76+
None
77+
)
8178
}
8279
test("missingOuter") {
83-
// Blocked by https://github.com/lampepfl/dotty/issues/12492
84-
TestUtils.scala2Only {
85-
barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==>
86-
Result.Failure.MismatchedArguments(
87-
Seq(
88-
ArgSig(
89-
Some("zzzz"),
90-
Some('z'),
91-
None,
92-
None,
93-
mainargs.TokensReader.StringRead,
94-
positional = false,
95-
hidden = false
96-
)
97-
),
98-
List(),
99-
List(),
100-
None
101-
)
102-
}
80+
barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==>
81+
Result.Failure.MismatchedArguments(
82+
Seq(
83+
ArgSig(
84+
Some("zzzz"),
85+
Some('z'),
86+
None,
87+
None,
88+
mainargs.TokensReader.StringRead,
89+
positional = false,
90+
hidden = false
91+
)
92+
),
93+
List(),
94+
List(),
95+
None
96+
)
10397
}
10498

10599
test("missingInnerOuter") {
106-
// Blocked by https://github.com/lampepfl/dotty/issues/12492
107-
TestUtils.scala2Only {
108-
barParser.constructRaw(Seq("-w", "-x", "1")) ==>
109-
Result.Failure.MismatchedArguments(
110-
Seq(
111-
ArgSig(
112-
None,
113-
Some('y'),
114-
None,
115-
None,
116-
mainargs.TokensReader.IntRead,
117-
positional = false,
118-
hidden = false
119-
),
120-
ArgSig(
121-
Some("zzzz"),
122-
Some('z'),
123-
None,
124-
None,
125-
mainargs.TokensReader.StringRead,
126-
positional = false,
127-
hidden = false
128-
)
100+
barParser.constructRaw(Seq("-w", "-x", "1")) ==>
101+
Result.Failure.MismatchedArguments(
102+
Seq(
103+
ArgSig(
104+
None,
105+
Some('y'),
106+
None,
107+
None,
108+
mainargs.TokensReader.IntRead,
109+
positional = false,
110+
hidden = false
129111
),
130-
List(),
131-
List(),
132-
None
133-
)
134-
}
112+
ArgSig(
113+
Some("zzzz"),
114+
Some('z'),
115+
None,
116+
None,
117+
mainargs.TokensReader.StringRead,
118+
positional = false,
119+
hidden = false
120+
)
121+
),
122+
List(),
123+
List(),
124+
None
125+
)
135126
}
127+
136128
test("failedInnerOuter") {
137-
TestUtils.scala2Only {
138-
assertMatch(
139-
barParser.constructRaw(
140-
Seq("-w", "-x", "xxx", "-y", "hohoho", "-z", "xxx")
141-
)
142-
) {
143-
case Result.Failure.InvalidArguments(
144-
Seq(
145-
Result.ParamError.Failed(
146-
ArgSig(None, Some('x'), None, None, _, false, _),
147-
Seq("xxx"),
148-
_
149-
),
150-
Result.ParamError.Failed(
151-
ArgSig(None, Some('y'), None, None, _, false, _),
152-
Seq("hohoho"),
153-
_
154-
)
129+
assertMatch(
130+
barParser.constructRaw(
131+
Seq("-w", "-x", "xxx", "-y", "hohoho", "-z", "xxx")
132+
)
133+
) {
134+
case Result.Failure.InvalidArguments(
135+
Seq(
136+
Result.ParamError.Failed(
137+
ArgSig(None, Some('x'), None, None, _, false, _),
138+
Seq("xxx"),
139+
_
140+
),
141+
Result.ParamError.Failed(
142+
ArgSig(None, Some('y'), None, None, _, false, _),
143+
Seq("hohoho"),
144+
_
155145
)
156-
) =>
157-
}
146+
)
147+
) =>
148+
158149
}
159150
}
160151
}
161152

162153
test("doubleNested") {
163-
TestUtils.scala2Only {
164-
quxParser.constructOrThrow(
165-
Seq("-w", "-x", "1", "-y", "2", "-z", "xxx", "--moo", "cow")
166-
) ==>
167-
Qux("cow", Bar(Flag(true), Foo(1, 2), "xxx"))
168-
}
154+
quxParser.constructOrThrow(
155+
Seq("-w", "-x", "1", "-y", "2", "-z", "xxx", "--moo", "cow")
156+
) ==>
157+
Qux("cow", Bar(Flag(true), Foo(1, 2), "xxx"))
169158
}
170159
test("success") {
171-
TestUtils.scala2Only {
172-
ParserForMethods(Main).runOrThrow(
173-
Seq("-x", "1", "-y", "2", "-z", "hello")
174-
) ==> "false 1 2 hello false"
175-
}
160+
ParserForMethods(Main).runOrThrow(
161+
Seq("-x", "1", "-y", "2", "-z", "hello")
162+
) ==> "false 1 2 hello false"
176163
}
177164
}
178165
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package mainargs
2+
import utest._
3+
4+
// Make sure
5+
object ClassWithDefaultTests extends TestSuite {
6+
@main
7+
case class Foo(x: Int, y: Int = 1)
8+
9+
implicit val fooParser: ParserForClass[Foo] = ParserForClass[Foo]
10+
11+
object Main {
12+
@main
13+
def run(foo: Foo, bool: Boolean = false) = s"${foo.x} ${foo.y} $bool"
14+
}
15+
16+
val mainParser = ParserForMethods(Main)
17+
18+
val tests = Tests {
19+
test("simple") {
20+
test("success") {
21+
fooParser.constructOrThrow(Seq("-x", "1", "-y", "2")) ==> Foo(1, 2)
22+
}
23+
test("default") {
24+
fooParser.constructOrThrow(Seq("-x", "0")) ==> Foo(0, 1)
25+
}
26+
test("missing") {
27+
fooParser.constructRaw(Seq()) ==>
28+
Result.Failure.MismatchedArguments(
29+
Seq(
30+
ArgSig(
31+
None,
32+
Some('x'),
33+
None,
34+
None,
35+
mainargs.TokensReader.IntRead,
36+
positional = false,
37+
hidden = false
38+
)
39+
),
40+
List(),
41+
List(),
42+
None
43+
)
44+
45+
}
46+
}
47+
48+
test("nested") {
49+
test("success"){
50+
mainParser.runOrThrow(Seq("-x", "1", "-y", "2", "--bool", "true")) ==> "1 2 true"
51+
}
52+
test("default"){
53+
mainParser.runOrThrow(Seq("-x", "1", "-y", "2")) ==> "1 2 false"
54+
}
55+
test("default2"){
56+
mainParser.runOrThrow(Seq("-x", "0")) ==> "0 1 false"
57+
}
58+
}
59+
}
60+
}

mainargs/test/src/ParserTests.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ object ParserTests extends TestSuite {
5151
) ==> Right("xxxxx")
5252
}
5353
test("constructEither") {
54-
TestUtils.scala2Only {
55-
// default values in classes not working on Scala 3
56-
classParser.constructEither(Array("--code", "println(1)")) ==>
57-
Right(ClassBase(code = Some("println(1)"), other = "hello"))
58-
}
54+
classParser.constructEither(Array("--code", "println(1)")) ==>
55+
Right(ClassBase(code = Some("println(1)"), other = "hello"))
5956
}
6057
}
6158
}

readme.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,11 @@ command-line friendly tool.
519519

520520
# Changelog
521521

522+
## master
523+
524+
- Fix handling of case class main method parameter default parameters in Scala 3
525+
[#88](https://github.com/com-lihaoyi/mainargs/pull/88)
526+
522527
## 0.5.0
523528

524529
- Remove hard-code support for mainargs.Leftover/Flag/Subparser to support

0 commit comments

Comments
 (0)