Skip to content

[NVPTX] Mark callseq insts as reading and writing memory #151376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1830,6 +1830,18 @@ def : Pat<(declare_array_param externalsym:$a, imm:$align, imm:$size),
def : Pat<(declare_scalar_param externalsym:$a, imm:$size),
(DECLARE_PARAM_scalar (to_texternsym $a), imm:$size)>;

// Call prototype wrapper, this is a dummy instruction that just prints it's
// operand which is string defining the prototype.
def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
def CallPrototype :
SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def ProtoIdent : Operand<i32> { let PrintMethod = "printProtoIdent"; }
def CALL_PROTOTYPE :
NVPTXInst<(outs), (ins ProtoIdent:$ident),
"$ident", [(CallPrototype (i32 texternalsym:$ident))]>;


foreach t = [I32RT, I64RT] in {
defvar inst_name = "MOV" # t.Size # "_PARAM";
def inst_name : BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src), "mov.b" # t.Size>;
Expand All @@ -1849,6 +1861,32 @@ defm ProxyRegB16 : ProxyRegInst<"b16", B16>;
defm ProxyRegB32 : ProxyRegInst<"b32", B32>;
defm ProxyRegB64 : ProxyRegInst<"b64", B64>;


// Callseq start and end

// Note: these nodes are marked as SDNPMayStore and SDNPMayLoad because
// they define the scope in which the declared params may be used. Therefore
// we add these flags to ensure ld.param and st.param are not sunk or hoisted
// out of that scope.

def callseq_start : SDNode<"ISD::CALLSEQ_START",
SDCallSeqStart<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
[SDNPHasChain, SDNPOutGlue,
SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;
def callseq_end : SDNode<"ISD::CALLSEQ_END",
SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPSideEffect, SDNPMayStore, SDNPMayLoad]>;

def Callseq_Start :
NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
"\\{ // callseq $amt1, $amt2",
[(callseq_start timm:$amt1, timm:$amt2)]>;
def Callseq_End :
NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
"\\} // callseq $amt1",
[(callseq_end timm:$amt1, timm:$amt2)]>;

//
// Load / Store Handling
//
Expand Down Expand Up @@ -2392,26 +2430,6 @@ def : Pat<(brcond i32:$a, bb:$target),
def : Pat<(brcond (i1 (setne i1:$a, -1)), bb:$target),
(CBranchOther $a, bb:$target)>;

// Call
def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
SDTCisVT<1, i32>]>;
def SDT_NVPTXCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;

def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_NVPTXCallSeqStart,
[SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_NVPTXCallSeqEnd,
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
SDNPSideEffect]>;

def Callseq_Start :
NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
"\\{ // callseq $amt1, $amt2",
[(callseq_start timm:$amt1, timm:$amt2)]>;
def Callseq_End :
NVPTXInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
"\\} // callseq $amt1",
[(callseq_end timm:$amt1, timm:$amt2)]>;

// trap instruction
def trapinst : BasicNVPTXInst<(outs), (ins), "trap", [(trap)]>, Requires<[noPTXASUnreachableBug]>;
// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
Expand All @@ -2420,18 +2438,6 @@ def trapexitinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>, Requires<[
// brkpt instruction
def debugtrapinst : BasicNVPTXInst<(outs), (ins), "brkpt", [(debugtrap)]>;

// Call prototype wrapper
def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
def CallPrototype :
SDNode<"NVPTXISD::CallPrototype", SDTCallPrototype,
[SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
def ProtoIdent : Operand<i32> {
let PrintMethod = "printProtoIdent";
}
def CALL_PROTOTYPE :
NVPTXInst<(outs), (ins ProtoIdent:$ident),
"$ident", [(CallPrototype (i32 texternalsym:$ident))]>;

def SDTDynAllocaOp :
SDTypeProfile<1, 2, [SDTCisSameAs<0, 1>, SDTCisInt<1>, SDTCisVT<2, i32>]>;

Expand Down
47 changes: 47 additions & 0 deletions llvm/test/CodeGen/NVPTX/ld-param-sink.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -verify-machineinstrs | FileCheck %s
; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}

target triple = "nvptx64-nvidia-cuda"

declare ptr @bar(i64)
declare i64 @baz()

define ptr @foo(i1 %cond) {
; CHECK-LABEL: foo(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b16 %rs<3>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b8 %rs1, [foo_param_0];
; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
; CHECK-NEXT: setp.ne.b16 %p1, %rs2, 0;
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .b64 retval0;
; CHECK-NEXT: call.uni (retval0), baz, ();
; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
; CHECK-NEXT: } // callseq 0
; CHECK-NEXT: @%p1 bra $L__BB0_2;
; CHECK-NEXT: // %bb.1: // %bb
; CHECK-NEXT: { // callseq 1, 0
; CHECK-NEXT: .param .b64 param0;
; CHECK-NEXT: .param .b64 retval0;
; CHECK-NEXT: st.param.b64 [param0], %rd2;
; CHECK-NEXT: call.uni (retval0), bar, (param0);
; CHECK-NEXT: } // callseq 1
; CHECK-NEXT: $L__BB0_2: // %common.ret
; CHECK-NEXT: st.param.b64 [func_retval0], 0;
; CHECK-NEXT: ret;
entry:
%call = call i64 @baz()
br i1 %cond, label %common.ret, label %bb

bb:
%tmp = call ptr @bar(i64 %call)
br label %common.ret

common.ret:
ret ptr null
}