diff --git a/include/ctranslate2/storage_view.h b/include/ctranslate2/storage_view.h index 8834ef651..d25d47048 100644 --- a/include/ctranslate2/storage_view.h +++ b/include/ctranslate2/storage_view.h @@ -238,6 +238,9 @@ namespace ctranslate2 { friend std::ostream& operator<<(std::ostream& os, const StorageView& storage); + template + void print_tensor(std::ostream& os, const T* data, const std::vector& shape, size_t dim, size_t offset, int indent) const; + protected: DataType _dtype = DataType::FLOAT32; Device _device = Device::CPU; diff --git a/src/storage_view.cc b/src/storage_view.cc index 0cbdf25c2..13f973bf8 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -440,44 +440,77 @@ namespace ctranslate2 { return os; } - std::ostream& operator<<(std::ostream& os, const StorageView& storage) { - StorageView printable(storage.dtype()); - printable.copy_from(storage); - TYPE_DISPATCH( - printable.dtype(), - const auto* values = printable.data(); - if (printable.size() <= PRINT_MAX_VALUES) { - for (dim_t i = 0; i < printable.size(); ++i) { - os << ' '; - print_value(os, values[i]); + std::ostream& operator<<(std::ostream& os, const StorageView& storage) { + // Create a printable copy of the storage + StorageView printable(storage.dtype()); + printable.copy_from(storage); + + // Check the data type and print accordingly + TYPE_DISPATCH( + printable.dtype(), + const auto* values = printable.data(); + const auto& shape = printable.shape(); + + // Print tensor contents based on dimensionality + if (shape.empty()) { // Scalar case + os << "Data (Scalar): " << values[0] << std::endl; + } else { + os << "Data (" << shape.size() << "D "; + if (shape.size() == 1) + os << "Vector"; + else if (shape.size() == 2) + os << "Matrix"; + else + os << "Tensor"; + os << "):" << std::endl; + printable.print_tensor(os, values, shape, 0, 0, 0); + os << std::endl; + } + ); + + // Print metadata + os << "[device:" << device_to_str(storage.device(), storage.device_index()) + << ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as "; + if (storage.is_scalar()) + os << "scalar"; + else { + for (dim_t i = 0; i < storage.rank(); ++i) { + if (i > 0) + os << 'x'; + os << storage.dim(i); } } - else { - for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) { - os << ' '; - print_value(os, values[i]); + os << ']'; + return os; + } + + template + void StorageView::print_tensor(std::ostream& os, const T* data, const std::vector& shape, size_t dim, size_t offset, int indent) const { + std::string indentation(indent, ' '); + + os << indentation << "["; + bool is_last_dim = (dim == shape.size() - 1); + for (dim_t i = 0; i < shape[dim]; ++i) { + if (i > 0) { + os << ", "; + if (!is_last_dim) { + os << "\n" << std::string(indent, ' '); + } } - os << " ..."; - for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) { - os << ' '; - print_value(os, values[i]); + + if (i == PRINT_MAX_VALUES / 2) { + os << "..."; + i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part + } else { + if (is_last_dim) { + os << data[offset + i]; + } else { + print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent); + } } } - os << std::endl); - os << "[" << device_to_str(storage.device(), storage.device_index()) - << " " << dtype_name(storage.dtype()) << " storage viewed as "; - if (storage.is_scalar()) - os << "scalar"; - else { - for (dim_t i = 0; i < storage.rank(); ++i) { - if (i > 0) - os << 'x'; - os << storage.dim(i); - } + os << "]"; } - os << ']'; - return os; - } #define DECLARE_IMPL(T) \ template \