Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port and run tests in python/test/unit/tools #2953

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ python/triton/backends/
# Language extras
python/triton/language/extra

# Tools extras
python/triton/tools/extra

# Proton
python/triton/profiler

Expand Down
57 changes: 40 additions & 17 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ class Backend:
name: str
package_data: List[str]
language_package_data: List[str]
tools_package_data: List[str]
src_dir: str
backend_dir: str
language_dir: Optional[str]
tools_dir: Optional[str]
install_dir: str
is_external: bool

Expand Down Expand Up @@ -68,6 +70,10 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
if not os.path.exists(language_dir):
language_dir = None

tools_dir = os.path.abspath(os.path.join(backend_src_dir, "tools"))
if not os.path.exists(tools_dir):
tools_dir = None

for file in ["compiler.py", "driver.py"]:
assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}"

Expand All @@ -78,9 +84,12 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool =
if language_dir is not None:
language_package_data = [f"{os.path.relpath(p, language_dir)}/*" for p, _, _, in os.walk(language_dir)]

tools_package_data = []
if tools_dir is not None:
tools_package_data = [f"{os.path.relpath(p, tools_dir)}/*" for p, _, _, in os.walk(tools_dir)]
return Backend(name=backend_name, package_data=package_data, language_package_data=language_package_data,
src_dir=backend_src_dir, backend_dir=backend_path, language_dir=language_dir,
install_dir=install_dir, is_external=is_external)
tools_package_data=tools_package_data, src_dir=backend_src_dir, backend_dir=backend_path,
language_dir=language_dir, tools_dir=tools_dir, install_dir=install_dir, is_external=is_external)

# Copy all in-tree backends under triton/third_party.
@staticmethod
Expand Down Expand Up @@ -600,6 +609,15 @@ def add_link_to_backends():
install_dir = os.path.join(extra_dir, x)
update_symlink(install_dir, src_dir)

if backend.tools_dir:
# Link the contents of each backend's `tools` directory into
# `triton.tools.extra`.
extra_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "triton", "tools", "extra"))
for x in os.listdir(backend.tools_dir):
src_dir = os.path.join(backend.tools_dir, x)
install_dir = os.path.join(extra_dir, x)
update_symlink(install_dir, src_dir)


def add_link_to_proton():
proton_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, "third_party", "proton", "proton"))
Expand Down Expand Up @@ -642,28 +660,31 @@ def run(self):


# Per-package data files shipped with the wheel. The scraped diff interleaved
# the old and new sides of this hunk (plus review-UI text); this is the
# reconstructed new-side dict:
#  - `triton/tools` carries the AOT compile templates,
#  - `triton/tools/extra` aggregates every backend's tools payload,
#  - each backend contributes its own package data,
#  - `triton/language/extra` aggregates every backend's language payload.
package_data = {
    "triton/tools": ["compile.h", "compile.c"],
    "triton/tools/extra": sum((b.tools_package_data for b in backends), []),
    **{f"triton/backends/{b.name}": b.package_data
       for b in backends}, "triton/language/extra": sum((b.language_package_data for b in backends), [])
}


def get_extra_packages(extra_name):
    """Enumerate the subpackages each backend contributes under
    ``triton/<extra_name>/extra``.

    Parameters
    ----------
    extra_name : str
        Either ``"language"`` or ``"tools"`` — selects which per-backend
        directory (``backend.language_dir`` / ``backend.tools_dir``) is
        walked, and which file extensions mark a directory as a real
        subpackage.

    Returns
    -------
    list[str]
        Package paths (e.g. ``"triton/tools/extra/foo"``) suitable for
        setuptools' ``packages`` list.
    """
    # NOTE: values must be tuples for str.endswith. The original wrote
    # ("..py") without a trailing comma, which is a plain string — it still
    # works with endswith, but only by accident; spell the tuple explicitly.
    extra_file_extensions = {"language": (".py", ), "tools": (".c", ".h", ".cpp")}
    assert extra_name in extra_file_extensions, f"{extra_name} extra is not valid"

    packages = []
    for backend in backends:
        backend_extra_dir = getattr(backend, f"{extra_name}_dir", None)
        if backend_extra_dir is None:
            # This backend ships no <extra_name> directory.
            continue

        # Walk the specified directory of each backend to enumerate any
        # subpackages, which will be added under `triton/<extra_name>/extra`.
        for dir, dirs, files in os.walk(backend_extra_dir, followlinks=True):
            if not any(f.endswith(extra_file_extensions[extra_name]) for f in files) or dir == backend_extra_dir:
                # Ignore directories with no relevant files, and the root
                # directory itself (it maps to the `extra` package, which is
                # always listed by get_packages()).
                continue
            subpackage = os.path.relpath(dir, backend_extra_dir)
            packages.append(os.path.join(f"triton/{extra_name}/extra", subpackage))

    return packages
Expand All @@ -679,9 +700,11 @@ def get_packages():
"triton/runtime",
"triton/backends",
"triton/tools",
"triton/tools/extra",
]
packages += [f'triton/backends/{backend.name}' for backend in backends]
packages += get_language_extra_packages()
packages += get_extra_packages("language")
packages += get_extra_packages("tools")
if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
packages += ["triton/profiler"]

Expand Down
180 changes: 163 additions & 17 deletions python/test/unit/tools/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import tempfile

import numpy as np
import pytest
etiotto marked this conversation as resolved.
Show resolved Hide resolved

import triton
from triton._internal_testing import is_cuda, is_xpu
from triton.backends.compiler import GPUTarget
from triton.backends.nvidia.driver import include_dir, library_dirs
from triton.backends.intel.driver import compilation_helper

kernel_utils_src = """
import triton
Expand Down Expand Up @@ -97,21 +100,42 @@ def kernel(C, A, B, M, N, K,
}"""


def gen_kernel_library_xpu(dir, libname):
    """Compile the generated ``*.cpp`` AOT kernel sources in *dir* into a
    shared library named *libname* using the Intel SYCL compiler (icpx).

    Mirrors the gcc path in gen_kernel_library, but targets XPU: sources
    are built with ``-fsycl`` and the link pulls in the Level Zero loader.
    """
    cpp_files = glob.glob(os.path.join(dir, "*.cpp"))
    # Compile each translation unit to a position-independent object file.
    subprocess.run(
        ["icpx"] + cpp_files + ["-I", compilation_helper.include_dir[0], "-c", "-fsycl", "-fPIC"],
        check=True,
        cwd=dir,
    )
    o_files = glob.glob(os.path.join(dir, "*.o"))

    # Link the objects into a shared library; -lze_loader resolves the
    # Level Zero entry points used by the generated launcher.
    command = ["icpx", "-fsycl", "-lze_loader", *o_files, "-shared", "-o", libname]
    for lib_dir in compilation_helper.library_dir:
        command.extend(["-L", lib_dir])
    if compilation_helper.libsycl_dir:
        for lib_dir in compilation_helper.libsycl_dir:
            command.extend(["-L", lib_dir])
    subprocess.run(command, check=True, cwd=dir)


def gen_kernel_library(dir, libname):
    """Build the AOT kernel sources in *dir* into shared library *libname*.

    On XPU this defers to the SYCL toolchain; otherwise the generated C
    sources are compiled with gcc and linked into a shared object.
    """
    if is_xpu():
        gen_kernel_library_xpu(dir, libname)
        return

    # Compile every generated .c file to a position-independent object.
    sources = glob.glob(os.path.join(dir, "*.c"))
    subprocess.run(
        ["gcc"] + sources + ["-I", include_dir[0], "-c", "-fPIC"],
        check=True,
        cwd=dir,
    )

    # Link all objects into the shared library, searching the driver's
    # library directories.
    objects = glob.glob(os.path.join(dir, "*.o"))
    link_cmd = ["gcc", *objects, "-shared", "-o", libname]
    for lib_dir in library_dirs():
        link_cmd.extend(["-L", lib_dir])
    subprocess.run(link_cmd, check=True, cwd=dir)


def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
test_src = f"""
int main(int argc, char **argv) {{
Expand Down Expand Up @@ -171,15 +195,118 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
}}
"""
src = test_utils_src + test_src
with open(os.path.join(dir, "test.c"), "w") as file:
if is_xpu():
src = f"""
#include "kernel.h"
#include <assert.h>
#include <cmath>
#include <cstddef>
#include <level_zero/ze_api.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sycl/sycl.hpp>

static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) {{
FILE *file = fopen(filename, "w");
if (file == NULL) {{
printf("Could not open file %s\\n", filename);
return;
}}
for (int i = 0; i < size; i++) {{
fprintf(file, "%d", buffer[i]);
if (i < size - 1) {{
fprintf(file, ",");
}}
}}
fclose(file);
}}

static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) {{
FILE *file = fopen(filename, "r");
if (file == NULL) {{
printf("Could not open file %s\\n", filename);
return;
}}
int index = 0;
while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) {{
index++;
}}
fclose(file);
}}
int main(int argc, char ** argv) {{
int M = {M}, N = {N}, K = {K};

// initialize sycl handles
sycl::queue q{{sycl::gpu_selector_v}};
sycl::ext::intel::device_ptr<sycl::float16> A =
sycl::malloc_device<sycl::float16>(M * K * 2, q);
sycl::ext::intel::device_ptr<sycl::float16> B =
sycl::malloc_device<sycl::float16>(K * N * 2, q);
sycl::ext::intel::device_ptr<sycl::float16> C =
sycl::malloc_device<sycl::float16>(M * N * 4, q);

// initialize input data
int16_t hA[M * K];
int16_t hB[K * N];
memset(hA, 0, M * K * 2);
memset(hB, 0, K * N * 2);
read_csv_to_buffer(argv[1], hA, M * K);
read_csv_to_buffer(argv[2], hB, K * N);
q.memcpy(A, hA, M * K * 2).wait();
q.memcpy(B, hB, K * N * 2).wait();

// launch kernel
load_matmul_fp16();
int32_t ret;
int algo_id = {algo_id};
if (algo_id == 0) {{
ret = matmul_fp16_default(q, C, A, B, M, N, K, N, 1, K, 1, N, 1);
}} else {{
ret = matmul_fp16(q, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id});
}}
if (ret != 0) fprintf(stderr, "kernel launch failed\\n");
assert(ret == 0);

q.wait();

// read data
int32_t hC[M * N];
memset(hC, 0, M * N * 4);
q.memcpy(hC, C, M * N * 4).wait();
write_buffer_to_csv(argv[3], hC, M * N);

// free sycl resources
unload_matmul_fp16();
sycl::free(A, q);
sycl::free(B, q);
sycl::free(C, q);
}}
"""
src_name = "test.c"
if is_xpu():
src_name = "test.cpp"
with open(os.path.join(dir, src_name), "w") as file:
file.write(src)

command = ["gcc", "test.c"]
for inc_dir in include_dir:
command.extend(["-I", inc_dir])
for lib_dir in library_dirs():
command.extend(["-L", lib_dir])
command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe])
if is_cuda():
command = ["gcc", "test.c"]
for inc_dir in include_dir:
command.extend(["-I", inc_dir])
for lib_dir in library_dirs():
command.extend(["-L", lib_dir])
command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe])

if is_xpu():
command = ["icpx", "test.cpp"]
for inc_dir in compilation_helper.include_dir:
command.extend(["-I", inc_dir])
for lib_dir in compilation_helper.library_dir:
command.extend(["-L", lib_dir])
if compilation_helper.libsycl_dir:
for lib_dir in compilation_helper.libsycl_dir:
command.extend(["-L", lib_dir])
command.extend(["-fsycl", "-lze_loader", "-L", dir, "-l", "kernel", "-o", exe])
subprocess.run(command, check=True, cwd=dir)


Expand Down Expand Up @@ -283,6 +410,8 @@ def test_compile_link_matmul_no_specialization():

with tempfile.TemporaryDirectory() as tmp_dir:
dtype = "fp16"
if is_xpu():
pytest.skip("FIXME: AssertionError on XPU")
BM, BN, BK = 16, 16, 16

kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src)
Expand All @@ -299,9 +428,8 @@ def test_compile_link_matmul_no_specialization():

# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
etiotto marked this conversation as resolved.
Show resolved Hide resolved
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)

# read data and compare against reference
c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
c_tri = c.reshape((M, N)).view(np.float32)
Expand Down Expand Up @@ -330,7 +458,7 @@ def test_compile_link_matmul():

# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir)

# read data and compare against reference
Expand Down Expand Up @@ -361,7 +489,7 @@ def test_launcher_has_no_available_kernel():

# run test case
env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
result = subprocess.run(
["./test", a_path, b_path, c_path],
env=env,
Expand Down Expand Up @@ -410,7 +538,7 @@ def test_compile_link_autotune_matmul():
gen_test_bin(tmp_dir, M, N, K, exe=test_name, algo_id=algo_id)

env = os.environ.copy()
env["LD_LIBRARY_PATH"] = tmp_dir
env["LD_LIBRARY_PATH"] = tmp_dir + ":" + env.get("LD_LIBRARY_PATH", "")
subprocess.run(
[f"./{test_name}", a_path, b_path, c_path],
check=True,
Expand Down Expand Up @@ -440,3 +568,21 @@ def test_ttgir_to_ptx():
ptx = k.asm["ptx"]
assert ".target sm_80" in ptx
assert ".address_size 64" in ptx


def test_ttgir_to_spv():
    """Round-trip a minimal TTGIR module through triton.compile and verify
    the SPIR-V disassembly carries the expected Intel kernel attributes."""
    ttgir = """
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.num-ctas" = 1 : i32} {
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
tt.return
}
}
"""
    with tempfile.TemporaryDirectory() as workdir:
        kernel_file = os.path.join(workdir, "empty_kernel.ttgir")
        with open(kernel_file, "w") as handle:
            handle.write(ttgir)
        compiled = triton.compile(kernel_file, target=triton.runtime.driver.active.get_current_target())
        disassembly = compiled.asm['spvdis']
        assert "OpCapability KernelAttributesINTEL" in disassembly
        assert "SubgroupSize 32" in disassembly
16 changes: 16 additions & 0 deletions python/test/unit/tools/test_disasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,19 @@ def kernel(X, i: tl.constexpr):
sass = h.asm["sass"]
# check that the sass has a store instruction.
assert "STG.E" in sass


def test_disam_spvbin():
    """Run a trivial store kernel on XPU and check that its SPIR-V
    disassembly contains a store instruction."""
    # NOTE(review): "disam" looks like a typo for "disasm" — renaming would
    # change the collected pytest id, so the name is kept as-is here.
    if triton.runtime.driver.active.get_current_target().backend != "xpu":
        pytest.skip("Test requires XPU.")

    @triton.jit
    def kernel(X, i: tl.constexpr):
        tl.store(X, i)

    buf = torch.empty(1, dtype=torch.int32, device='xpu')
    handle = kernel[(1, )](buf, i=12)
    # Sanity-check the kernel actually ran before inspecting the assembly.
    assert buf[0] == 12
    disassembly = handle.asm["spvdis"]
    # The compiled kernel must contain a SPIR-V store instruction.
    assert "OpStore" in disassembly
Loading