aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJoshua Bakita <bakitajoshua@gmail.com>2023-11-29 18:05:01 -0500
committerJoshua Bakita <bakitajoshua@gmail.com>2023-11-29 18:24:25 -0500
commit973b919cfe6d05fdb3b82f538b1afbc3233a7008 (patch)
tree0964c9a9d94345dce6def66c1812e0204374b1ef
parent8062646a185baa6d3934d1e19743ac671e943fa8 (diff)
Abort process on error, and better document callback-based masking
-rw-r--r--libsmctrl.c42
1 files changed, 25 insertions, 17 deletions
diff --git a/libsmctrl.c b/libsmctrl.c
index 526331f..817cb5d 100644
--- a/libsmctrl.c
+++ b/libsmctrl.c
@@ -156,19 +156,29 @@ static uint64_t g_sm_mask = 0;
156static __thread uint64_t g_next_sm_mask = 0; 156static __thread uint64_t g_next_sm_mask = 0;
157static char sm_control_setup_called = 0; 157static char sm_control_setup_called = 0;
158static void launchCallback(void *ukwn, int domain, int cbid, const void *in_params) { 158static void launchCallback(void *ukwn, int domain, int cbid, const void *in_params) {
159 if (*(uint32_t*)in_params < 0x50) { 159 // The third 8-byte element in `in_params` is a pointer to the stream struct.
160 fprintf(stderr, "Unsupported CUDA version for callback-based SM masking. Aborting...\n"); 160 // This exists even when in_params < 0x50. This could be used to implement
161 return; 161 // stream masking without the manual offsets specified elsewhere (store a
162 } 162 // table of stream pointers to masks and do a lookup here).
163 if (!**((uintptr_t***)in_params+8)) { 163 // It could also be used (although not as easily) to support global and next
164 fprintf(stderr, "Called with NULL halLaunchDataAllocation\n"); 164 // masking on old CUDA versions, but that would require hooking earlier in the
165 return; 165 // launch process (before the stream mask is applied).
166 } 166 if (*(uint32_t*)in_params < 0x50)
167 //fprintf(stderr, "cta: %lx\n", *(uint64_t*)(**((char***)in_params + 8) + 74)); 167 abort(1, 0, "Unsupported CUDA version for callback-based SM masking. Aborting...");
168 // The eighth 8-byte element in `in_params` is a pointer to a struct which
169 // contains a pointer to the TMD as its first element. Note that this eighth
170 // pointer must exist---it only exists when the first 8-byte element of
171 // `in_params` is at least 0x50 (checked above).
172 void* tmd = **((uintptr_t***)in_params + 8);
173 if (!tmd)
174 abort(1, 0, "TMD allocation appears NULL; likely forward-compatibilty issue.\n");
175
176 //fprintf(stderr, "cta: %lx\n", *(uint64_t*)(tmd + 74));
168 // TODO: Check for supported QMD version (>XXX, <4.00) 177 // TODO: Check for supported QMD version (>XXX, <4.00)
169 // TODO: Support QMD version 4 (Hopper), where offset starts at +304 (rather than +84) and is 72 bytes (rather than 8 bytes) wide 178 // TODO: Support QMD version 4 (Hopper), where offset starts at +304 (rather than +84) and is 16 bytes (rather than 8 bytes) wide. It also requires an enable bit at +31bits.
170 uint32_t *lower_ptr = (uint32_t*)(**((char***)in_params + 8) + 84); 179 uint32_t *lower_ptr = tmd + 84;
171 uint32_t *upper_ptr = (uint32_t*)(**((char***)in_params + 8) + 88); 180 uint32_t *upper_ptr = tmd + 88;
181
172 if (g_next_sm_mask) { 182 if (g_next_sm_mask) {
173 *lower_ptr = (uint32_t)g_next_sm_mask; 183 *lower_ptr = (uint32_t)g_next_sm_mask;
174 *upper_ptr = (uint32_t)(g_next_sm_mask >> 32); 184 *upper_ptr = (uint32_t)(g_next_sm_mask >> 32);
@@ -198,13 +208,11 @@ static void setup_sm_control_11() {
198 enable = (typeof(enable))enable_func_addr; 208 enable = (typeof(enable))enable_func_addr;
199 int res = 0; 209 int res = 0;
200 res = subscribe(&my_hndl, launchCallback, NULL); 210 res = subscribe(&my_hndl, launchCallback, NULL);
201 if (res) { 211 if (res)
202 fprintf(stderr, "libsmctrl: Error subscribing to launch callback. Error %d\n", res); 212 abort(1, 0, "Error subscribing to launch callback. CUDA returned error code %d.", res);
203 return;
204 }
205 res = enable(1, my_hndl, LAUNCH_DOMAIN, LAUNCH_PRE_UPLOAD); 213 res = enable(1, my_hndl, LAUNCH_DOMAIN, LAUNCH_PRE_UPLOAD);
206 if (res) 214 if (res)
207 fprintf(stderr, "libsmctrl: Error enabling launch callback. Error %d\n", res); 215 abort(1, 0, "Error enabling launch callback. CUDA returned error code %d.", res);
208} 216}
209 217
210// Set default mask for all launches 218// Set default mask for all launches