- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[RISCV] Refactor performCONCAT_VECTORSCombine. NFC #69068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Instead of doing a forward pass for positive strides and a reverse pass for negative strides, we can just do one pass by negating the offset if the pointers do happen to be in reverse order. We can extend getPtrDiff later in llvm#68726 to handle more constant offset sequences.
| @llvm/pr-subscribers-backend-risc-v Author: Luke Lau (lukel97) ChangesInstead of doing a forward pass for positive strides and a reverse pass for We can extend getPtrDiff later in #68726 to handle more constant offset Full diff: https://github.com/llvm/llvm-project/pull/69068.diff 1 Files Affected: 
 diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d7552317fd8bc69..9912f19c9a50191 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   EVT BaseLdVT = BaseLd->getValueType(0);
-  SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SmallVector<SDValue> Ptrs;
-  Ptrs.push_back(BasePtr);
+  SmallVector<LoadSDNode *> Lds;
+  Lds.push_back(BaseLd);
   Align Align = BaseLd->getAlign();
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,58 +13797,33 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
+  auto getPtrDiff = [&DAG, &DL](LoadSDNode *Ld1, LoadSDNode *Ld2) {
+    SDValue P1 = Ld1->getBasePtr();
+    SDValue P2 = Ld2->getBasePtr();
+    if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
+      return P2.getOperand(1);
+    if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
+      return DAG.getNegative(P1.getOperand(1), DL,
+                             P1.getOperand(1).getValueType());
+    return SDValue();
   };
 
-  bool Reversed = false;
-  SDValue Stride = matchForwardStrided(Ptrs);
-  if (!Stride) {
-    Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
+  SDValue Stride;
+  for (auto [Idx, Ld] : enumerate(Lds)) {
+    if (Idx == 0)
+      continue;
+    SDValue Offset = getPtrDiff(Lds[Idx - 1], Ld);
+    if (!Offset)
+      return SDValue();
     if (!Stride)
+      Stride = Offset;
+    else if (Offset != Stride)
       return SDValue();
   }
 
@@ -13871,22 +13845,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue IntID =
     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                           Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
 
-  SDValue Ops[] = {BaseLd->getChain(),
-                   IntID,
-                   DAG.getUNDEF(WideVecVT),
-                   BasePtr,
-                   Stride,
-                   AllOneMask};
+  SDValue Ops[] = {BaseLd->getChain(),   IntID,  DAG.getUNDEF(WideVecVT),
+                   BaseLd->getBasePtr(), Stride, AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +
 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making this change! LGTM.
| if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1) | ||
| return P2.getOperand(1); | ||
| if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2) | ||
| return DAG.getNegative(P1.getOperand(1), DL, | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to create this node if the overall match fails?
| if (!TLI.isLegalStridedLoadStore(WideVecVT, Align)) | ||
| return SDValue(); | ||
|  | ||
| auto [Stride, Reversed] = *BaseDiff; | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at #68726, I wonder if Reversed may not be the best name for this variable? The RVV spec uses the term negative stride. Additionally, when BaseIndexOffset returns a negative value, it does not need to be negated here although that load is "in reverse". I suggest we rename this to something along the lines of MustNegateStride?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's a better name, will change. (Also open to any other ideas on how we can defer creating the negate node)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am content with this approach.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Instead of doing a forward pass for positive strides and a reverse pass for
negative strides, we can just do one pass by negating the offset if the
pointers do happen to be in reverse order.
We can extend getPtrDiff later in #68726 to handle more constant offset
sequences.