Skip to content

Commit 84d2ea2

Browse files
authored
Opening Hat Dialect for Handling Barriers and Shared/Private Memory (#560)
1 parent 5ddaa81 commit 84d2ea2

36 files changed

+1920
-159
lines changed

hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/PTXHATKernelBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
*/
2525
package hat.backend.ffi;
2626

27+
import hat.dialect.HatOP;
2728
import hat.ifacemapper.BoundSchema;
2829
import hat.optools.*;
2930
import hat.codebuilders.CodeBuilder;
@@ -163,7 +164,7 @@ public void functionEpilogue() {
163164
cbrace();
164165
}
165166

166-
public static class PTXPtrOp extends Op {
167+
public static class PTXPtrOp extends HatOP {
167168
public String fieldName;
168169
public static final String NAME = "ptxPtr";
169170
final TypeElement resultType;

hat/backends/ffi/cuda/src/main/native/cpp/cuda_backend_queue.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,25 @@ void CudaBackend::CudaQueue::dispatch(KernelContext *kernelContext, CompilationU
178178
std::cout << "dispatch() thread=" <<thread_id<< " != "<< streamCreationThread<< std::endl;
179179
}
180180

181+
// // CUDA events for timing
182+
// cudaEvent_t start, stop;
183+
// cuEventCreate(&start, cudaEventDefault);
184+
// cuEventCreate(&stop, cudaEventDefault);
185+
// cuEventRecord(start, 0);
186+
181187
const auto status = cuLaunchKernel(cudaKernel->function, //
182188
blocksPerGridX, blocksPerGridY, blocksPerGridZ, //
183189
threadsPerBlockX, threadsPerBlockY, threadsPerBlockZ, //
184190
0, //
185191
cuStream, //
186192
cudaKernel->argslist, //
187193
nullptr);
194+
// cuEventRecord(stop, 0);
195+
// cuEventSynchronize(stop); // Wait for completion
196+
//
197+
// float elapsedTimeMs = 0.0f;
198+
// cuEventElapsedTime(&elapsedTimeMs, start, stop);
199+
// std::cout << "Kernel Elapsed Time: " << elapsedTimeMs << " ms\n";
188200

189201
CUDA_CHECK(status, "cuLaunchKernel");
190202
}

hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,13 @@
3636
import hat.buffer.KernelContext;
3737
import hat.callgraph.KernelCallGraph;
3838
import hat.codebuilders.ScopedCodeBuilderContext;
39+
import hat.dialect.HatMemoryOp;
3940
import hat.ifacemapper.BoundSchema;
4041
import hat.ifacemapper.BufferState;
4142
import hat.ifacemapper.Schema;
4243
import hat.optools.OpTk;
4344
import jdk.incubator.code.CopyContext;
4445
import jdk.incubator.code.Op;
45-
import jdk.incubator.code.dialect.java.ClassType;
46-
import jdk.incubator.code.dialect.java.JavaOp;
47-
import jdk.incubator.code.dialect.java.JavaType;
4846

4947
import java.lang.invoke.MethodHandles;
5048
import java.lang.reflect.InvocationTargetException;
@@ -142,16 +140,15 @@ public void dispatch(NDRange ndRange, Object[] args) {
142140

143141
public Map<KernelCallGraph, CompiledKernel> kernelCallGraphCompiledCodeMap = new HashMap<>();
144142

145-
private void updateListOfSchemas(Op op, MethodHandles.Lookup lookup, List<String> localIfaceList) {
146-
if (Objects.requireNonNull(op) instanceof JavaOp.InvokeOp invokeOp) {
147-
if (OpTk.isIfaceAccessor(lookup, invokeOp)) {
148-
String klassName = invokeOp.resultType().toString();
149-
localIfaceList.add(klassName);
150-
}
143+
private void updateListOfSchemas(Op op, List<String> localIfaceList) {
144+
if (Objects.requireNonNull(op) instanceof HatMemoryOp hatMemoryOp) {
145+
String klassName = hatMemoryOp.invokeType().toString();
146+
localIfaceList.add(klassName);
151147
}
152148
}
153149

154150
public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kernelCallGraph, T builder, Object... args) {
151+
155152
builder.defines().pragmas().types();
156153
Set<Schema.IfaceType> already = new LinkedHashSet<>();
157154
Arrays.stream(args)
@@ -171,9 +168,9 @@ public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kern
171168
// Traverse the list of reachable functions and append the intrinsics functions found for each of the functions
172169
if (kernelCallGraph.moduleOp != null) {
173170
kernelCallGraph.moduleOp.functionTable()
174-
.forEach((_, funcOp) -> {
175-
funcOp.transform(CopyContext.create(), (blockBuilder, op) -> {
176-
updateListOfSchemas(op, kernelCallGraph.computeContext.accelerator.lookup, localIFaceList);
171+
.forEach((entryName, f) -> {
172+
f.transform(CopyContext.create(), (blockBuilder, op) -> {
173+
updateListOfSchemas(op, localIFaceList);
177174
blockBuilder.op(op);
178175
return blockBuilder;
179176
});
@@ -183,7 +180,7 @@ public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kern
183180
// this else-branch will be deleted.
184181
kernelCallGraph.kernelReachableResolvedStream().forEach((kernel) -> {
185182
kernel.funcOp().transform(CopyContext.create(), (blockBuilder, op) -> {
186-
updateListOfSchemas(op, kernelCallGraph.computeContext.accelerator.lookup, localIFaceList);
183+
updateListOfSchemas(op, localIFaceList);
187184
blockBuilder.op(op);
188185
return blockBuilder;
189186
});
@@ -193,7 +190,7 @@ public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kern
193190
// Traverse the main kernel and append the intrinsics functions found in the main kernel
194191
kernelCallGraph.entrypoint.funcOp()
195192
.transform(CopyContext.create(), (blockBuilder, op) -> {
196-
updateListOfSchemas(op, kernelCallGraph.computeContext.accelerator.lookup, localIFaceList);
193+
updateListOfSchemas(op, localIFaceList);
197194
blockBuilder.op(op);
198195
return blockBuilder;
199196
});

hat/core/src/main/java/hat/ComputeContext.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,22 @@
2929
import hat.buffer.BufferTracker;
3030
import hat.callgraph.ComputeCallGraph;
3131
import hat.callgraph.KernelCallGraph;
32+
import hat.dialect.HatBarrierOp;
3233
import hat.ifacemapper.BoundSchema;
3334
import hat.ifacemapper.SegmentMapper;
3435
import hat.optools.OpTk;
36+
import jdk.incubator.code.Block;
37+
import jdk.incubator.code.CopyContext;
3538
import jdk.incubator.code.Op;
3639
import jdk.incubator.code.Quotable;
3740
import jdk.incubator.code.Quoted;
41+
import jdk.incubator.code.Value;
3842
import jdk.incubator.code.dialect.core.CoreOp;
3943
import jdk.incubator.code.dialect.java.JavaOp;
4044
import jdk.incubator.code.dialect.java.MethodRef;
4145

4246
import java.lang.reflect.Method;
47+
import java.util.List;
4348
import java.util.function.Consumer;
4449

4550
/**
@@ -132,13 +137,39 @@ public void dispatchKernel(ComputeRange computeRange, QuotableKernelContextConsu
132137
dispatchKernelWithComputeRange(computeRange, quotableKernelContextConsumer);
133138
}
134139

140+
private boolean isMethodFromHatKernelContext(JavaOp.InvokeOp invokeOp) {
141+
String kernelContextCanonicalName = hat.KernelContext.class.getName();
142+
return invokeOp.invokeDescriptor().refType().toString().equals(kernelContextCanonicalName);
143+
}
144+
145+
private boolean isMethod(JavaOp.InvokeOp invokeOp, String methodName) {
146+
return invokeOp.invokeDescriptor().name().equals(methodName);
147+
}
148+
149+
private void createBarrierNodeOp(CopyContext context, JavaOp.InvokeOp invokeOp, Block.Builder blockBuilder) {
150+
List<Value> inputOperands = invokeOp.operands();
151+
List<Value> outputOperands = context.getValues(inputOperands);
152+
HatBarrierOp hatBarrierOp = new HatBarrierOp(outputOperands);
153+
Op.Result outputResult = blockBuilder.op(hatBarrierOp);
154+
Op.Result inputResult = invokeOp.result();
155+
context.mapValue(inputResult, outputResult);
156+
}
157+
135158
record CallGraph(Quoted quoted, JavaOp.LambdaOp lambdaOp, MethodRef methodRef, KernelCallGraph kernelCallGraph) {}
136159

137160
private CallGraph buildKernelCallGraph(QuotableKernelContextConsumer quotableKernelContextConsumer) {
138161
Quoted quoted = Op.ofQuotable(quotableKernelContextConsumer).orElseThrow();
139162
JavaOp.LambdaOp lambdaOp = (JavaOp.LambdaOp) quoted.op();
140-
MethodRef methodRef =OpTk.getQuotableTargetInvokeOpWrapper( lambdaOp).invokeDescriptor();
163+
MethodRef methodRef = OpTk.getQuotableTargetInvokeOpWrapper( lambdaOp).invokeDescriptor();
141164
KernelCallGraph kernelCallGraph = computeCallGraph.kernelCallGraphMap.get(methodRef);
165+
// Analysis : dialect
166+
// NOTE: Keep the following boolean until we have the config available/reachable
167+
// from this class
168+
boolean useDialect = true;
169+
if (useDialect) {
170+
//System.out.println("[INFO] Using Hat Dialect?: " + useDialect);
171+
kernelCallGraph.dialectifyToHat();
172+
}
142173
return new CallGraph(quoted, lambdaOp, methodRef, kernelCallGraph);
143174
}
144175

@@ -157,6 +188,7 @@ private void dispatchKernel(int rangeX, int rangeY, int rangeZ, int dimNumber, Q
157188
accelerator.backend.dispatchKernel(cg.kernelCallGraph, ndRange, args);
158189
} catch (Throwable t) {
159190
System.out.print("what?" + cg.methodRef + " " + t);
191+
t.printStackTrace();
160192
throw t;
161193
}
162194
}
@@ -170,6 +202,7 @@ private void dispatchKernelWithComputeRange(ComputeRange computeRange, QuotableK
170202
accelerator.backend.dispatchKernel(cg.kernelCallGraph, ndRange, args);
171203
} catch (Throwable t) {
172204
System.out.print("what?" + cg.methodRef + " " + t);
205+
t.printStackTrace();
173206
throw t;
174207
}
175208
}

hat/core/src/main/java/hat/buffer/Buffer.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
package hat.buffer;
2626

2727

28-
import hat.Space;
2928
import hat.ifacemapper.BoundSchema;
3029
import hat.ifacemapper.BufferState;
3130
import hat.ifacemapper.MappableIface;

hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,11 @@ public void closeWithModuleOp(ComputeReachableResolvedMethodCall computeReachabl
229229

230230
@Override
231231
public boolean filterCalls(CoreOp.FuncOp f, JavaOp.InvokeOp invokeOp, Method method, MethodRef methodRef, Class<?> javaRefTypeClass) {
232-
if (entrypoint.method.getDeclaringClass().equals(OpTk.javaRefClassOrThrow(computeContext.accelerator.lookup,invokeOp)) && isKernelDispatch(computeContext.accelerator.lookup,method, f)) {
232+
if (entrypoint.method.getDeclaringClass().equals(OpTk.javaRefClassOrThrow(computeContext.accelerator.lookup,invokeOp))
233+
&& isKernelDispatch(computeContext.accelerator.lookup,method, f)) {
233234
kernelCallGraphMap.computeIfAbsent(methodRef, _ ->
234-
new KernelCallGraph(this, methodRef, method, f).closeWithModuleOp()
235+
new KernelCallGraph(this, methodRef, method, f)
236+
.closeWithModuleOp()
235237
);
236238
} else if (ComputeContext.class.isAssignableFrom(javaRefTypeClass)) {
237239
computeContextMethodCall = new ComputeContextMethodCall(this, methodRef, method);

0 commit comments

Comments
 (0)