1515using namespace ::testing;
1616using torch::executor::Error;
1717using torch::executor::KernelRuntimeContext;
18+ using torch::executor::MemoryAllocator;
19+ using torch::executor::Result;
1820
1921class KernelRuntimeContextTest : public ::testing::Test {
2022 public:
@@ -23,6 +25,17 @@ class KernelRuntimeContextTest : public ::testing::Test {
2325 }
2426};
2527
28+ class TestMemoryAllocator : public MemoryAllocator {
29+ public:
30+ TestMemoryAllocator (uint32_t size, uint8_t * base_address)
31+ : MemoryAllocator(size, base_address), last_seen_alignment(0 ) {}
32+ void * allocate (size_t size, size_t alignment) {
33+ last_seen_alignment = alignment;
34+ return MemoryAllocator::allocate (size, alignment);
35+ }
36+ size_t last_seen_alignment;
37+ };
38+
2639TEST_F (KernelRuntimeContextTest, FailureStateDefaultsToOk) {
2740 KernelRuntimeContext context;
2841
@@ -47,3 +60,43 @@ TEST_F(KernelRuntimeContextTest, FailureStateReflectsFailure) {
4760 context.fail (Error::Ok);
4861 EXPECT_EQ (context.failure_state (), Error::Ok);
4962}
63+
64+ TEST_F (KernelRuntimeContextTest, FailureNoMemoryAllocatorProvided) {
65+ KernelRuntimeContext context;
66+ Result<void *> allocated_memory = context.allocate_temp (4 );
67+ EXPECT_EQ (allocated_memory.error (), Error::NotFound);
68+ }
69+
70+ TEST_F (KernelRuntimeContextTest, SuccessfulMemoryAllocation) {
71+ constexpr size_t temp_memory_allocator_pool_size = 4 ;
72+ auto temp_memory_allocator_pool =
73+ std::make_unique<uint8_t []>(temp_memory_allocator_pool_size);
74+ MemoryAllocator temp_allocator (
75+ temp_memory_allocator_pool_size, temp_memory_allocator_pool.get ());
76+ KernelRuntimeContext context (nullptr , &temp_allocator);
77+ Result<void *> allocated_memory = context.allocate_temp (4 );
78+ EXPECT_EQ (allocated_memory.ok (), true );
79+ }
80+
81+ TEST_F (KernelRuntimeContextTest, FailureMemoryAllocationInsufficientSpace) {
82+ constexpr size_t temp_memory_allocator_pool_size = 4 ;
83+ auto temp_memory_allocator_pool =
84+ std::make_unique<uint8_t []>(temp_memory_allocator_pool_size);
85+ MemoryAllocator temp_allocator (
86+ temp_memory_allocator_pool_size, temp_memory_allocator_pool.get ());
87+ KernelRuntimeContext context (nullptr , &temp_allocator);
88+ Result<void *> allocated_memory = context.allocate_temp (8 );
89+ EXPECT_EQ (allocated_memory.error (), Error::MemoryAllocationFailed);
90+ }
91+
92+ TEST_F (KernelRuntimeContextTest, MemoryAllocatorAlignmentPassed) {
93+ constexpr size_t temp_memory_allocator_pool_size = 4 ;
94+ auto temp_memory_allocator_pool =
95+ std::make_unique<uint8_t []>(temp_memory_allocator_pool_size);
96+ TestMemoryAllocator temp_allocator (
97+ temp_memory_allocator_pool_size, temp_memory_allocator_pool.get ());
98+ KernelRuntimeContext context (nullptr , &temp_allocator);
99+ Result<void *> allocated_memory = context.allocate_temp (4 , 2 );
100+ EXPECT_EQ (allocated_memory.ok (), true );
101+ EXPECT_EQ (temp_allocator.last_seen_alignment , 2 );
102+ }
0 commit comments