Skip to content

Commit

Permalink
Remove Intel code
Browse files Browse the repository at this point in the history
  • Loading branch information
chengjunlu committed Dec 6, 2024
1 parent b1c3c72 commit 841bc4f
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
// with Intel layouts.
// More details:
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
return order;
}
// if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
// SmallVector<unsigned> order(rank);
// std::iota(order.rbegin(), order.rend(), 0);
// return order;
// }
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
Expand Down Expand Up @@ -1129,10 +1129,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
return amdWmmaParent.getTotalElemsPerThreadForOperand(
shape, eltTy, getKWidth(), getOpIdx());
}
if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
return dpasParent.getTotalElemsPerThreadForOperand(
shape, eltTy, getKWidth(), getOpIdx());
}
// if (auto dpasParent =
// mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
// return dpasParent.getTotalElemsPerThreadForOperand(
// shape, eltTy, getKWidth(), getOpIdx());
// }
}
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
auto shapePerCTA = getShapePerCTA(*this, shape);
Expand Down Expand Up @@ -1197,17 +1198,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
// general solution to make `getOrderForDotOperand` function compatible
// with Intel layouts.
// More details:
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
return ::getOrder(*this);
} else {
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
}
// // FIXME: delete if branch for `DpasEncodingAttr` and provide more
// // general solution to make `getOrderForDotOperand` function compatible
// // with Intel layouts.
// // More details:
// // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
// if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
// return ::getOrder(*this);
// } else {
// return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
// /*kMajor*/ true);
// }
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
}

LogicalResult DotOperandEncodingAttr::verify(
Expand Down Expand Up @@ -1250,20 +1253,17 @@ LogicalResult DotOperandEncodingAttr::verify(
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
if (kWidth != parentAttr.getOpsPerChannel())
return emitError() << "ttg.dot_op kWidth parameter must match the "
"parent's opsPerChannel";
return success();
}

if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError() << "ttg.dot_op kWidth parameter is not supported "
"when the parent is a warp layout";
return success();
}

if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) {
return success();
}

if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError() << "ttg.dot_op kWidth parameter is not supported "
Expand Down Expand Up @@ -3248,8 +3248,7 @@ struct CanonicalizeConvertFromConvert
auto srcType = op.getSrc().getType();
auto dstType = op.getType();
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
(mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) ||
mlir::isa<intel::DpasEncodingAttr>(srcType.getEncoding())))
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
return failure();

// for hopper MMAv3
Expand Down

0 comments on commit 841bc4f

Please sign in to comment.