Upstream changes in lib/Dialect/TritonGPU/IR/Dialect.cpp #2917

Closed
whitneywhtsang opened this issue Dec 4, 2024 · 0 comments · Fixed by #2950
Labels: enhancement (New feature or request), upstream: triton

whitneywhtsang commented Dec 4, 2024

```diff
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -5,6 +5,9 @@
 
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
+
+#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
+
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -208,12 +211,15 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout) {
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
     auto sizePerThread = distributedLayout.getSizePerThread();
     auto threadsPerWarp = distributedLayout.getThreadsPerWarp();
-    // ThreadsPerWarp does not align with this function for slice layout
+    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
+    // ThreadsPerWarp and warpsPerCTA does not align with this function for
+    // slice layout
     if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
       threadsPerWarp = getThreadsPerWarp(sliceLayout.getParent());
       threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
+      warpsPerCTA = getWarpsPerCTA(sliceLayout.getParent());
+      warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
     }
-    auto warpsPerCTA = distributedLayout.getWarpsPerCTA();
     assert(sizePerThread.size() == threadsPerWarp.size() &&
            sizePerThread.size() == warpsPerCTA.size());
     SmallVector<unsigned> shape;
@@ -305,6 +311,16 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
+    // FIXME: delete if branch for `DpasEncodingAttr` and provide more
+    // general solution to make `getOrderForDotOperand` function compatible
+    // with Intel layouts.
+    // More details:
+    // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+    if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
+      SmallVector<unsigned> order(rank);
+      std::iota(order.rbegin(), order.rend(), 0);
+      return order;
+    }
     return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1165,8 +1188,17 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                               /*kMajor*/ true);
+  // FIXME: delete if branch for `DpasEncodingAttr` and provide more
+  // general solution to make `getOrderForDotOperand` function compatible
+  // with Intel layouts.
+  // More details:
+  // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+  if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+    return ::getOrder(*this);
+  } else {
+    return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                                 /*kMajor*/ true);
+  }
 }
```
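
For illustration, a minimal standalone C++ sketch of the two patterns the diff relies on. This is plain STL code, not the MLIR sources: the helper names `reversedIotaOrder` and `eraseSlicedDim` are hypothetical, and `std::vector` stands in for `SmallVector`. It shows how `std::iota` over reverse iterators produces the descending order vector used in the `DpasEncodingAttr` special case of `getOrder`, and how erasing the sliced dimension from the parent layout's per-dimension vectors works, as done for `threadsPerWarp` and `warpsPerCTA` in `getShapePerCTATile`.

```cpp
// Standalone sketch (assumed helper names, std::vector instead of SmallVector).
#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

// Pattern 1: std::iota over reverse iterators fills {rank-1, ..., 1, 0},
// i.e. the descending order returned for a DotOperandEncodingAttr whose
// parent is a DpasEncodingAttr.
std::vector<unsigned> reversedIotaOrder(unsigned rank) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  return order; // rank == 3 -> {2, 1, 0}
}

// Pattern 2: for a slice layout, take the parent's per-dimension values and
// erase the sliced dimension, as the diff now does for both threadsPerWarp
// and warpsPerCTA.
std::vector<unsigned> eraseSlicedDim(std::vector<unsigned> parentValues,
                                     unsigned sliceDim) {
  assert(sliceDim < parentValues.size());
  parentValues.erase(parentValues.begin() + sliceDim);
  return parentValues;
}

int main() {
  for (unsigned d : reversedIotaOrder(3))
    std::printf("%u ", d); // prints: 2 1 0

  auto warpsPerCTA = eraseSlicedDim({4, 2}, /*sliceDim=*/0);
  std::printf("\n%u\n", warpsPerCTA[0]); // prints: 2
  return 0;
}
```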