Commit

Remove Intel code
chengjunlu committed Dec 17, 2024
1 parent f1a893a commit 699a84f
Showing 2 changed files with 33 additions and 28 deletions.
lib/Dialect/TritonGPU/IR/Dialect.cpp: 54 changes (27 additions & 27 deletions)
@@ -309,11 +309,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     // with Intel layouts.
     // More details:
     // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
-    if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
-      SmallVector<unsigned> order(rank);
-      std::iota(order.rbegin(), order.rend(), 0);
-      return order;
-    }
+    // if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
+    //   SmallVector<unsigned> order(rank);
+    //   std::iota(order.rbegin(), order.rend(), 0);
+    //   return order;
+    // }
     return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
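For context: the branch commented out above built the order vector back-to-front with std::iota over reverse iterators, which produces {rank-1, ..., 1, 0}. A minimal standalone C++ sketch of just that idiom (plain std, no MLIR dependencies):

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  unsigned rank = 3;
  std::vector<unsigned> order(rank);
  // iota over the reverse iterators assigns 0 to the last element,
  // so the result is {2, 1, 0}: most-major dimension listed first.
  std::iota(order.rbegin(), order.rend(), 0u);
  for (unsigned d : order)
    std::printf("%u ", d); // prints: 2 1 0
  return 0;
}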
@@ -1093,10 +1093,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
       return amdWmmaParent.getTotalElemsPerThreadForOperand(
           shape, eltTy, getKWidth(), getOpIdx());
     }
-    if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
-      return dpasParent.getTotalElemsPerThreadForOperand(
-          shape, eltTy, getKWidth(), getOpIdx());
-    }
+    // if (auto dpasParent =
+    //         mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
+    //   return dpasParent.getTotalElemsPerThreadForOperand(
+    //       shape, eltTy, getKWidth(), getOpIdx());
+    // }
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1161,17 +1162,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  // FIXME: delete if branch for `DpasEncodingAttr` and provide more
-  // general solution to make `getOrderForDotOperand` function compatible
-  // with Intel layouts.
-  // More details:
-  // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
-  if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
-    return ::getOrder(*this);
-  } else {
-    return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                                 /*kMajor*/ true);
-  }
+  // // FIXME: delete if branch for `DpasEncodingAttr` and provide more
+  // // general solution to make `getOrderForDotOperand` function compatible
+  // // with Intel layouts.
+  // // More details:
+  // // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
+  // if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
+  //   return ::getOrder(*this);
+  // } else {
+  //   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+  //                                /*kMajor*/ true);
+  // }
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 
 LogicalResult DotOperandEncodingAttr::verify(
@@ -1214,20 +1217,17 @@ LogicalResult DotOperandEncodingAttr::verify(
     return success();
   }
 
-  if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
-    if (kWidth != parentAttr.getOpsPerChannel())
-      return emitError() << "ttg.dot_op kWidth parameter must match the "
-                            "parent's opsPerChannel";
-    return success();
-  }
-
   if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
     if (kWidth != 0)
       return emitError() << "ttg.dot_op kWidth parameter is not supported "
                             "when the parent is a warp layout";
     return success();
   }
 
+  if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) {
+    return success();
+  }
+
   if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
     if (kWidth != 0)
       return emitError() << "ttg.dot_op kWidth parameter is not supported "
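Net effect of the verify() hunk above: the Dpas-specific rule (kWidth must match the parent's opsPerChannel) is deleted, and a remaining MmaEncodingTrait parent is accepted with no kWidth constraint at this point. A toy standalone model of the resulting dispatch; the enum and helper are hypothetical stand-ins for the MLIR attribute checks, and the blocked-layout message is truncated in the diff:

#include <optional>
#include <string>

enum class ParentKind { Warp, MmaTrait, Blocked };

// Returns an error message on failure, std::nullopt on success.
std::optional<std::string> verifyKWidth(ParentKind parent, unsigned kWidth) {
  if (parent == ParentKind::Warp && kWidth != 0)
    return "ttg.dot_op kWidth parameter is not supported "
           "when the parent is a warp layout";
  if (parent == ParentKind::MmaTrait)
    return std::nullopt; // accepted unconditionally after this commit
  if (parent == ParentKind::Blocked && kWidth != 0)
    return "ttg.dot_op kWidth parameter is not supported ..."; // truncated
  return std::nullopt;
}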
Second changed file: 6 additions & 1 deletion
@@ -515,7 +515,12 @@ struct LoadOpConversion
     const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
 
     DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
-    auto dotOrder = dotLayout.getThreadOrder();
+    std::optional<LinearLayout> dotLL =
+        dotLayout.toLinearLayout(tensorType.getShape());
+    assert(dotLL.has_value() && "invalid dot layout to linear layout");
+    LinearEncodingAttr dotLLAttr =
+        LinearEncodingAttr::get(rewriter.getContext(), *dotLL);
+    SmallVector<unsigned> dotOrder = dotLLAttr.getThreadOrder();
     size_t rank = dotOrder.size();
     const bool valueRowMajor =
         (dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
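The replacement path no longer asks the DotOperandEncodingAttr for its thread order directly; it converts the dot layout to a LinearLayout, wraps that in a LinearEncodingAttr, and reads the thread order from there. The row-major test that follows is unchanged; restated below as a standalone C++ helper, assuming Triton's convention that an order vector lists dimensions from fastest- to slowest-varying:

#include <cassert>
#include <cstddef>
#include <vector>

bool isValueRowMajor(const std::vector<unsigned> &dotOrder) {
  size_t rank = dotOrder.size();
  assert(rank >= 2 && "expected at least a 2-D operand");
  // Row major: dim 1 (columns) varies fastest, dim 0 (rows) slowest.
  return dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0;
}

// e.g. isValueRowMajor({1, 0}) == true; isValueRowMajor({0, 1}) == false.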

0 comments on commit 699a84f
