@@ -49,27 +49,113 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
49
49
MarkDeviceFunction (Sema &S)
50
50
: RecursiveASTVisitor<MarkDeviceFunction>(), SemaRef(S) {}
51
51
bool VisitCallExpr (CallExpr *e) {
52
+ for (const auto &Arg : e->arguments ())
53
+ CheckTypeForVirtual (Arg->getType (), Arg->getSourceRange ());
54
+
52
55
if (FunctionDecl *Callee = e->getDirectCallee ()) {
53
56
// Remember that all SYCL kernel functions have deferred
54
57
// instantiation as template functions. It means that
55
58
// all functions used by kernel have already been parsed and have
56
59
// definitions.
60
+
61
+ CheckTypeForVirtual (Callee->getReturnType (), Callee->getSourceRange ());
62
+
57
63
if (FunctionDecl *Def = Callee->getDefinition ()) {
58
64
if (!Def->hasAttr <SYCLDeviceAttr>()) {
59
65
Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
60
66
this ->TraverseStmt (Def->getBody ());
61
- // But because parser works with top level declarations and CodeGen
62
- // already saw and ignored our function without device attribute we
63
- // need to add this function into SYCL kernels array to show it
64
- // this function again.
65
67
SemaRef.AddSyclKernel (Def);
66
68
}
67
69
}
68
70
}
69
71
return true ;
70
72
}
71
73
74
+ bool VisitCXXConstructExpr (CXXConstructExpr *E) {
75
+ for (const auto &Arg : E->arguments ())
76
+ CheckTypeForVirtual (Arg->getType (), Arg->getSourceRange ());
77
+
78
+ CXXConstructorDecl *Ctor = E->getConstructor ();
79
+
80
+ if (FunctionDecl *Def = Ctor->getDefinition ()) {
81
+ Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
82
+ this ->TraverseStmt (Def->getBody ());
83
+ SemaRef.AddSyclKernel (Def);
84
+ }
85
+
86
+ const auto *ConstructedType = Ctor->getParent ();
87
+ if (ConstructedType->hasUserDeclaredDestructor ()) {
88
+ CXXDestructorDecl *Dtor = ConstructedType->getDestructor ();
89
+
90
+ if (FunctionDecl *Def = Dtor->getDefinition ()) {
91
+ Def->addAttr (SYCLDeviceAttr::CreateImplicit (SemaRef.Context ));
92
+ this ->TraverseStmt (Def->getBody ());
93
+ SemaRef.AddSyclKernel (Def);
94
+ }
95
+ }
96
+ return true ;
97
+ }
98
+
99
+ bool VisitTypedefNameDecl (TypedefNameDecl *TD) {
100
+ CheckTypeForVirtual (TD->getUnderlyingType (), TD->getLocation ());
101
+ return true ;
102
+ }
103
+
104
+ bool VisitRecordDecl (RecordDecl *RD) {
105
+ CheckTypeForVirtual (QualType{RD->getTypeForDecl (), 0 }, RD->getLocation ());
106
+ return true ;
107
+ }
108
+
109
+ bool VisitParmVarDecl (VarDecl *VD) {
110
+ CheckTypeForVirtual (VD->getType (), VD->getLocation ());
111
+ return true ;
112
+ }
113
+
114
+ bool VisitVarDecl (VarDecl *VD) {
115
+ CheckTypeForVirtual (VD->getType (), VD->getLocation ());
116
+ return true ;
117
+ }
118
+
119
+ bool VisitDeclRefExpr (DeclRefExpr *E) {
120
+ CheckTypeForVirtual (E->getType (), E->getSourceRange ());
121
+ return true ;
122
+ }
123
+
72
124
private:
125
+ bool CheckTypeForVirtual (QualType Ty, SourceRange Loc) {
126
+ while (Ty->isAnyPointerType () || Ty->isArrayType ())
127
+ Ty = QualType{Ty->getPointeeOrArrayElementType (), 0 };
128
+
129
+ if (const auto *CRD = Ty->getAsCXXRecordDecl ()) {
130
+ if (CRD->isPolymorphic ()) {
131
+ SemaRef.Diag (CRD->getLocation (), diag::err_sycl_virtual_types);
132
+ SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
133
+ return false ;
134
+ }
135
+
136
+ for (const auto &Field : CRD->fields ()) {
137
+ if (!CheckTypeForVirtual (Field->getType (), Field->getSourceRange ())) {
138
+ SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
139
+ return false ;
140
+ }
141
+ }
142
+ } else if (const auto *RD = Ty->getAsRecordDecl ()) {
143
+ for (const auto &Field : RD->fields ()) {
144
+ if (!CheckTypeForVirtual (Field->getType (), Field->getSourceRange ())) {
145
+ SemaRef.Diag (Loc.getBegin (), diag::note_sycl_used_here);
146
+ return false ;
147
+ }
148
+ }
149
+ } else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
150
+ for (const auto &ParamTy : FPTy->param_types ())
151
+ if (!CheckTypeForVirtual (ParamTy, Loc))
152
+ return false ;
153
+ return CheckTypeForVirtual (FPTy->getReturnType (), Loc);
154
+ } else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
155
+ return CheckTypeForVirtual (FTy->getReturnType (), Loc);
156
+ }
157
+ return true ;
158
+ }
73
159
Sema &SemaRef;
74
160
};
75
161
0 commit comments