
[CommonCodeClean] Clean changes in common code (#2950)
Remove the Intel DPAS special cases from the common TritonGPU code; the Intel block-load lowering now derives the thread order from the linear layout instead.
chengjunlu authored Dec 19, 2024
1 parent fdab3bb commit d662e65
Showing 2 changed files with 21 additions and 48 deletions.
27 changes: 2 additions & 25 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -304,16 +304,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
-    // FIXME: delete if branch for `DpasEncodingAttr` and provide more
-    // general solution to make `getOrderForDotOperand` function compatible
-    // with Intel layouts.
-    // More details:
-    // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
-    if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
-      SmallVector<unsigned> order(rank);
-      std::iota(order.rbegin(), order.rend(), 0);
-      return order;
-    }
     return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
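For context on the deleted branch: it hard-coded a reversed-iota order for dot operands with a DPAS parent, independent of the operand index. A minimal standalone sketch (illustrative only, not repository code) of what that std::iota idiom evaluates to:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const size_t rank = 3;              // any rank >= 2 behaves the same way
  std::vector<unsigned> order(rank);
  // Same idiom as the deleted branch: fill back-to-front, so the result is
  // {rank-1, ..., 1, 0}, i.e. the innermost tensor dimension comes first.
  std::iota(order.rbegin(), order.rend(), 0u);
  for (unsigned d : order)
    printf("%u ", d);                 // prints: 2 1 0
  printf("\n");
  return 0;
}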
@@ -1093,10 +1083,6 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
       return amdWmmaParent.getTotalElemsPerThreadForOperand(
           shape, eltTy, getKWidth(), getOpIdx());
     }
-    if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
-      return dpasParent.getTotalElemsPerThreadForOperand(
-          shape, eltTy, getKWidth(), getOpIdx());
-    }
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1159,17 +1145,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  // FIXME: delete if branch for `DpasEncodingAttr` and provide more
-  // general solution to make `getOrderForDotOperand` function compatible
-  // with Intel layouts.
-  // More details:
-  // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
-  if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
-    return ::getOrder(*this);
-  } else {
-    return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
-                                 /*kMajor*/ true);
-  }
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 
 LogicalResult DotOperandEncodingAttr::verify(
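The generic path now used for every parent layout is getOrderForDotOperand(opIdx, rank, /*kMajor*/ true). As a rough sketch, and as an assumption rather than a quote of the upstream implementation, it makes the operand's K dimension the fastest-varying when kMajor is true; orderForDotOperandSketch below is a hypothetical stand-in, not an upstream symbol:

#include <cstdio>
#include <vector>

// Assumed convention: K is the last dimension of operand A (opIdx 0) and the
// second-to-last dimension of operand B (opIdx 1); kMajor puts K fastest.
std::vector<unsigned> orderForDotOperandSketch(unsigned opIdx, unsigned rank,
                                               bool kMajor) {
  unsigned kDim = (opIdx == 0) ? rank - 1 : rank - 2;
  unsigned otherDim = (opIdx == 0) ? rank - 2 : rank - 1;
  std::vector<unsigned> order;
  order.push_back(kMajor ? kDim : otherDim);
  order.push_back(kMajor ? otherDim : kDim);
  for (int d = static_cast<int>(rank) - 3; d >= 0; --d)
    order.push_back(static_cast<unsigned>(d)); // remaining batch dimensions
  return order;
}

int main() {
  for (unsigned d : orderForDotOperandSketch(/*opIdx=*/0, /*rank=*/2, true))
    printf("%u ", d); // 1 0 -> same as the deleted reversed-iota order
  printf("\n");
  for (unsigned d : orderForDotOperandSketch(/*opIdx=*/1, /*rank=*/2, true))
    printf("%u ", d); // 0 1 -> K-major order for operand B
  printf("\n");
  return 0;
}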
42 changes: 19 additions & 23 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -526,6 +526,21 @@ struct LoadOpConversion
     };
     auto opIdx = getOpIdx();
 
+    std::optional<LinearLayout> llEncoding =
+        cast<DistributedEncodingTrait>(encoding).toLinearLayout(
+            tensorType.getShape());
+    assert(llEncoding.has_value() && "invalid dot layout to linear layout");
+    LinearEncodingAttr llAttr =
+        LinearEncodingAttr::get(rewriter.getContext(), *llEncoding);
+    SmallVector<unsigned> threadOrder = llAttr.getThreadOrder();
+    size_t rank = threadOrder.size();
+    const bool valueRowMajor =
+        (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0);
+    assert((valueRowMajor ||
+            (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
+           "Only row_major or column_major is allowed");
+    const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
+
     Type eltTy = tensorType.getElementType();
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
 
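The block added above derives the thread order from the operand's linear layout and decides whether a transpose is needed by comparing the value order against the memory order. A self-contained sketch of just that check (needsTranspose is an assumed helper name, and plain std::vector stands in for SmallVector):

#include <cassert>
#include <cstdio>
#include <vector>

// Mirrors the added check: the value layout counts as row major when the
// relevant thread-order entries are {1, 0}; a transpose is needed when the
// in-register order and the in-memory order disagree.
static bool needsTranspose(const std::vector<unsigned> &threadOrder,
                           bool memoryRowMajor) {
  size_t rank = threadOrder.size();
  const bool valueRowMajor =
      threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0;
  assert((valueRowMajor ||
          (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
         "Only row_major or column_major is allowed");
  return valueRowMajor ^ memoryRowMajor;
}

int main() {
  // Row-major values on a row-major buffer: no transpose.
  printf("%d\n", needsTranspose({1, 0}, /*memoryRowMajor=*/true)); // 0
  // Column-major values on a row-major buffer: transpose required.
  printf("%d\n", needsTranspose({0, 1}, /*memoryRowMajor=*/true)); // 1
  return 0;
}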
@@ -539,15 +554,15 @@ struct LoadOpConversion
     SmallVector<int64_t> numReps =
         dpasLayout.getDPASRepetitions(tensorShape, opIdx);
     const SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
-    SmallVector<unsigned> dpasOrder = triton::gpu::getOrder(dpasLayout);
+    SmallVector<unsigned> dpasWarpsOrder = triton::gpu::getOrder(dpasLayout);
     int threadsPerWarp = triton::gpu::getWarpSize(dpasLayout);
 
     Value warpId = rewriter.create<arith::IndexCastOp>(
         loc, i32_ty,
         rewriter.create<mlir::gpu::SubgroupIdOp>(loc, /*upperBound=*/nullptr));
 
     SmallVector<Value> multiDimWarpId =
-        delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
+        delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder);
 
     if (hasDpasLayout) {
       // A block load with the DPAS layout but without the DotDpasLayout is
@@ -557,14 +572,6 @@
       // aligns to the DPAS layout as the DPAS operation output layout
       // distributes rows across work items.
 
-      size_t rank = dpasOrder.size();
-      const bool valueRowMajor =
-          (dpasOrder[rank - 2] == 1 && dpasOrder[rank - 1] == 0);
-      assert((valueRowMajor ||
-              (dpasOrder[rank - 2] == 0 && dpasOrder[rank - 1] == 1)) &&
-             "Only row_major or column_major is allowed");
-      const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
-
       if (isTransposeRequired) {
         // TODO: this would likely require a shuffle to match the expected
         // ordering coming out of the DPAS layout and requires more
@@ -675,17 +682,6 @@ struct LoadOpConversion
       return success();
     }
 
-    DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
-    auto dotOrder = dotLayout.getThreadOrder();
-
-    size_t rank = dotOrder.size();
-    const bool valueRowMajor =
-        (dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
-    assert((valueRowMajor ||
-            (dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
-           "Only row_major or column_major is allowed");
-    const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
-
     bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
     SmallVector<unsigned> dpasInstShape = isOperandA
                                               ? dpasLayout.getDPASInstShapeA()
@@ -749,8 +745,8 @@ struct LoadOpConversion
                offsetBaseY] =
         getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
 
-    unsigned tileWidth = elemsPerDPASInst[dotOrder[rank - 2]];
-    unsigned tileHeight = elemsPerDPASInst[dotOrder[rank - 1]];
+    unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
+    unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
     unsigned vBlocks = 1;
     unsigned numOperandsOuterDimPerLoad = 1;
     unsigned numOperandsInnerDimPerLoad = 1;
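With the dot-operand-specific branch gone, the 2D block-load tile shape is read off the per-instruction element counts through the same threadOrder computed earlier, so one code path serves both value orders. A tiny sketch with assumed example numbers (the real elemsPerDPASInst values come from the DPAS layout, and the {8, 16} pair below is purely illustrative):

#include <cstdio>

int main() {
  const unsigned elemsPerDPASInst[2] = {8, 16}; // assumed per-dimension counts
  const unsigned threadOrder[2] = {1, 0};       // row-major value order
  const size_t rank = 2;
  unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
  unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
  printf("tileWidth=%u tileHeight=%u\n", tileWidth, tileHeight); // 16 and 8
  return 0;
}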
