Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Capacity aware partitioning #22766

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 4 additions & 1 deletion include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ struct OrtRunOptions;

namespace onnxruntime {

class IResourceAccountant;

/**
Logical device representation.
*/
Expand Down Expand Up @@ -130,7 +132,8 @@ class IExecutionProvider {
*/
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& kernel_lookup) const;
const IKernelLookup& kernel_lookup,
IResourceAccountant* resource_accountant = nullptr) const;

/**
Get kernel registry per execution provider type.
Expand Down
48 changes: 48 additions & 0 deletions include/onnxruntime/core/framework/resource_accountant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <optional>
#include <variant>

namespace onnxruntime {
#ifndef SHARED_PROVIDER
class Graph;
#else
struct Graph;
#endif

// Common holder for potentially different resource accounting
// for different EPs
using ResourceCount = std::variant<size_t>;

/// <summary>
/// This class is used for graph partitioning by EPs
/// It stores the cumulative amount of the resource such as
/// memory that would be consumed by the graph nodes if it is assigned to the EP.
///
/// It provides interfaces to add, remove and query the resource consumption.
///
/// Each provider may assign its own meaning to the resource according to its constraints.
/// </summary>
class IResourceAccountant {
protected:
IResourceAccountant() = default;
IResourceAccountant(const ResourceCount& threshold) : threshold_(threshold) {}

public:
virtual ~IResourceAccountant() = default;
virtual ResourceCount GetConsumedAmount() const = 0;
virtual void AddConsumedAmount(const ResourceCount& amount) = 0;
virtual void RemoveConsumedAmount(const ResourceCount& amount) = 0;
virtual ResourceCount ComputeResourceCount(const Graph&, size_t node_index) const = 0;
std::optional<ResourceCount> GetThreshold() const {
return threshold_;
}

private:
std::optional<ResourceCount> threshold_;
};

} // namespace onnxruntime
7 changes: 7 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,13 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return ConstGraphNodes(nodes_, std::move(filter_func));
}

/** Compute node memory requirements, which is mostly initializers
and large attributes that are copied on device (special cases for some nodes)

Returns no value if the node was not found.
*/
size_t ComputeNodeMemoryUsage(NodeIndex) const;

/** Gets the maximum NodeIndex value used in the Graph.
WARNING: This actually returns the max index value used + 1.
*/
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/graph/graph_viewer.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <filesystem>

#include "core/graph/graph.h"
#include "core/framework/resource_accountant.h"
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
#include "core/framework/session_options.h"

namespace onnxruntime {
Expand Down Expand Up @@ -189,6 +190,10 @@ class GraphViewer {
*/
const IndexedSubGraph* GetFilterInfo() const { return filter_info_; }

size_t ComputeNodeMemoryUsage(NodeIndex node_index) const {
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
return graph_->ComputeNodeMemoryUsage(node_index);
}

#if !defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); }
#endif
Expand Down
49 changes: 49 additions & 0 deletions include/onnxruntime/core/graph/indexed_sub_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <string>
#include <vector>

#include "core/common/inlined_containers_fwd.h"
#include "core/framework/resource_accountant.h"
#include "core/graph/basic_types.h"
#include "core/graph/onnx_protobuf.h"

Expand Down Expand Up @@ -70,9 +72,56 @@ struct IndexedSubGraph {
return meta_def_.get();
}

// Check if the accounting is enabled for the current EP
bool IsAccountingEnabled() const {
return resource_accountant != nullptr &&
nodes_costs.size() == nodes.size();
}

// Should call IsAccountingEnabled() first
// Takes the previously computed ResourceCount for the node
// (usually during GetCapabiilty())
// if present and adds it to the consumed amount
void AccountForNode(size_t cost_index) const {
assert(cost_index < nodes_costs.size());
if (nodes_costs[cost_index].has_value()) {
resource_accountant->AddConsumedAmount(*nodes_costs[cost_index]);
}
}

// This computes and accounts for the resource cost for the node that just
// been fused from other nodes, and the EP did not had a chance to compute the costs.
void ComputeAndAccountForNode(const Graph& graph, size_t node_index) const {
assert(resource_accountant != nullptr);
resource_accountant->AddConsumedAmount(resource_accountant->ComputeResourceCount(graph, node_index));
}

void SetAccountant(IResourceAccountant* res_accountant) {
resource_accountant = res_accountant;
}

// Append resource count to the list of costs for the nodes.
void AppendNodeCost(const ResourceCount& cost) {
assert(resource_accountant != nullptr);
nodes_costs.emplace_back(cost);
}

// Append an absent cost for the node that was already accounted for.
void AppendNodeEmptyCost() {
assert(resource_accountant != nullptr);
nodes_costs.emplace_back();
}

private:
// subgraph meta definition.
std::unique_ptr<MetaDef> meta_def_;
// Optional resource accountant for this subgraph.
IResourceAccountant* resource_accountant = nullptr;
// Vector with resource costs for nodes above. Should have the same size
// Some nodes that were previously accounted for because they already been assigned to an EP
// for example during multiple calls to GetCapabiility() will not have resource count present.
// may not have a resource count present, we skip it.
InlinedVector<std::optional<ResourceCount>> nodes_costs;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ static const char* const kOrtSessionOptionsConfigStrictAllowReleasedOpsetsOnly =
// The file saves configuration for partitioning node among logic streams
static const char* const kNodePartitionConfigFile = "session.node_partition_config_file";

/// "number > 0": enables Capacity Aware Partitioning for Cuda EP. The EP will place nodes on device
/// "0" : disables Capacity Aware Partitioning for Cuda EP. The EP will place nodes on device based on the default policy.
/// until the device memory usage reaches the specified threshold in Kb. The default value is 0.
static const char* const kOrtSessionOptionsConfigPartitionSetCudaMemoryLimitKb = "session.node_partition_memory_limit_kb";

// This Option allows setting affinities for intra op threads.
// Affinity string follows format:
// logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ namespace onnxruntime {

std::vector<std::unique_ptr<ComputeCapability>>
IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
const IKernelLookup& kernel_lookup) const {
const IKernelLookup& kernel_lookup,
IResourceAccountant*) const {
std::vector<std::unique_ptr<ComputeCapability>> result;
for (const auto& node : graph.Nodes()) {
if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node);
Expand Down
Loading
Loading