diff options
-rw-r--r-- drivers/misc/cxl/api.c     |   2
-rw-r--r-- drivers/misc/cxl/context.c |   6
-rw-r--r-- drivers/misc/cxl/cxl.h     |   3
-rw-r--r-- drivers/misc/cxl/fault.c   | 129
-rw-r--r-- drivers/misc/cxl/file.c    |   6
5 files changed, 109 insertions, 37 deletions
diff --git a/drivers/misc/cxl/api.c b/drivers/misc/cxl/api.c
index a6543aefa299..ea3eeb7011e1 100644
--- a/drivers/misc/cxl/api.c
+++ b/drivers/misc/cxl/api.c
@@ -172,7 +172,7 @@ int cxl_start_context(struct cxl_context *ctx, u64 wed,
172 | 172 | ||
173 | if (task) { | 173 | if (task) { |
174 | ctx->pid = get_task_pid(task, PIDTYPE_PID); | 174 | ctx->pid = get_task_pid(task, PIDTYPE_PID); |
175 | get_pid(ctx->pid); | 175 | ctx->glpid = get_task_pid(task->group_leader, PIDTYPE_PID); |
176 | kernel = false; | 176 | kernel = false; |
177 | } | 177 | } |
178 | 178 | ||
diff --git a/drivers/misc/cxl/context.c b/drivers/misc/cxl/context.c
index 6dde7a9d6a7e..262b88eac414 100644
--- a/drivers/misc/cxl/context.c
+++ b/drivers/misc/cxl/context.c
@@ -42,7 +42,7 @@ int cxl_context_init(struct cxl_context *ctx, struct cxl_afu *afu, bool master,
42 | spin_lock_init(&ctx->sste_lock); | 42 | spin_lock_init(&ctx->sste_lock); |
43 | ctx->afu = afu; | 43 | ctx->afu = afu; |
44 | ctx->master = master; | 44 | ctx->master = master; |
45 | ctx->pid = NULL; /* Set in start work ioctl */ | 45 | ctx->pid = ctx->glpid = NULL; /* Set in start work ioctl */ |
46 | mutex_init(&ctx->mapping_lock); | 46 | mutex_init(&ctx->mapping_lock); |
47 | ctx->mapping = mapping; | 47 | ctx->mapping = mapping; |
48 | 48 | ||
@@ -217,7 +217,11 @@ int __detach_context(struct cxl_context *ctx)
217 | WARN_ON(cxl_detach_process(ctx) && | 217 | WARN_ON(cxl_detach_process(ctx) && |
218 | cxl_adapter_link_ok(ctx->afu->adapter)); | 218 | cxl_adapter_link_ok(ctx->afu->adapter)); |
219 | flush_work(&ctx->fault_work); /* Only needed for dedicated process */ | 219 | flush_work(&ctx->fault_work); /* Only needed for dedicated process */ |
220 | |||
221 | /* release the reference to the group leader and mm handling pid */ | ||
220 | put_pid(ctx->pid); | 222 | put_pid(ctx->pid); |
223 | put_pid(ctx->glpid); | ||
224 | |||
221 | cxl_ctx_put(); | 225 | cxl_ctx_put(); |
222 | return 0; | 226 | return 0; |
223 | } | 227 | } |
diff --git a/drivers/misc/cxl/cxl.h b/drivers/misc/cxl/cxl.h
index 25ae57fa79b0..a521bc72cec2 100644
--- a/drivers/misc/cxl/cxl.h
+++ b/drivers/misc/cxl/cxl.h
@@ -445,6 +445,9 @@ struct cxl_context {
445 | unsigned int sst_size, sst_lru; | 445 | unsigned int sst_size, sst_lru; |
446 | 446 | ||
447 | wait_queue_head_t wq; | 447 | wait_queue_head_t wq; |
448 | /* pid of the group leader associated with the pid */ | ||
449 | struct pid *glpid; | ||
450 | /* use mm context associated with this pid for ds faults */ | ||
448 | struct pid *pid; | 451 | struct pid *pid; |
449 | spinlock_t lock; /* Protects pending_irq_mask, pending_fault and fault_addr */ | 452 | spinlock_t lock; /* Protects pending_irq_mask, pending_fault and fault_addr */ |
450 | /* Only used in PR mode */ | 453 | /* Only used in PR mode */ |
diff --git a/drivers/misc/cxl/fault.c b/drivers/misc/cxl/fault.c
index 25a5418c55cb..81c3f75b7330 100644
--- a/drivers/misc/cxl/fault.c
+++ b/drivers/misc/cxl/fault.c
@@ -166,13 +166,92 @@ static void cxl_handle_page_fault(struct cxl_context *ctx,
166 | cxl_ack_irq(ctx, CXL_PSL_TFC_An_R, 0); | 166 | cxl_ack_irq(ctx, CXL_PSL_TFC_An_R, 0); |
167 | } | 167 | } |
168 | 168 | ||
169 | /* | ||
170 | * Returns the mm_struct corresponding to the context ctx via ctx->pid | ||
171 | * In case the task has exited we use the task group leader accessible | ||
172 | * via ctx->glpid to find the next task in the thread group that has a | ||
173 | * valid mm_struct associated with it. If a task with valid mm_struct | ||
174 | * is found the ctx->pid is updated to use the task struct for subsequent | ||
175 | * translations. In case no valid mm_struct is found in the task group to | ||
176 | * service the fault a NULL is returned. | ||
177 | */ | ||
178 | static struct mm_struct *get_mem_context(struct cxl_context *ctx) | ||
179 | { | ||
180 | struct task_struct *task = NULL; | ||
181 | struct mm_struct *mm = NULL; | ||
182 | struct pid *old_pid = ctx->pid; | ||
183 | |||
184 | if (old_pid == NULL) { | ||
185 | pr_warn("%s: Invalid context for pe=%d\n", | ||
186 | __func__, ctx->pe); | ||
187 | return NULL; | ||
188 | } | ||
189 | |||
190 | task = get_pid_task(old_pid, PIDTYPE_PID); | ||
191 | |||
192 | /* | ||
193 | * pid_alive may look racy but this saves us from costly | ||
194 | * get_task_mm when the task is a zombie. In worst case | ||
195 | * we may think a task is alive, which is about to die | ||
196 | * but get_task_mm will return NULL. | ||
197 | */ | ||
198 | if (task != NULL && pid_alive(task)) | ||
199 | mm = get_task_mm(task); | ||
200 | |||
201 | /* release the task struct that was taken earlier */ | ||
202 | if (task) | ||
203 | put_task_struct(task); | ||
204 | else | ||
205 | pr_devel("%s: Context owning pid=%i for pe=%i dead\n", | ||
206 | __func__, pid_nr(old_pid), ctx->pe); | ||
207 | |||
208 | /* | ||
209 | * If we couldn't find the mm context then use the group | ||
210 | * leader to iterate over the task group and find a task | ||
211 | * that gives us mm_struct. | ||
212 | */ | ||
213 | if (unlikely(mm == NULL && ctx->glpid != NULL)) { | ||
214 | |||
215 | rcu_read_lock(); | ||
216 | task = pid_task(ctx->glpid, PIDTYPE_PID); | ||
217 | if (task) | ||
218 | do { | ||
219 | mm = get_task_mm(task); | ||
220 | if (mm) { | ||
221 | ctx->pid = get_task_pid(task, | ||
222 | PIDTYPE_PID); | ||
223 | break; | ||
224 | } | ||
225 | task = next_thread(task); | ||
226 | } while (task && !thread_group_leader(task)); | ||
227 | rcu_read_unlock(); | ||
228 | |||
229 | /* check if we switched pid */ | ||
230 | if (ctx->pid != old_pid) { | ||
231 | if (mm) | ||
232 | pr_devel("%s:pe=%i switch pid %i->%i\n", | ||
233 | __func__, ctx->pe, pid_nr(old_pid), | ||
234 | pid_nr(ctx->pid)); | ||
235 | else | ||
236 | pr_devel("%s:Cannot find mm for pid=%i\n", | ||
237 | __func__, pid_nr(old_pid)); | ||
238 | |||
239 | /* drop the reference to older pid */ | ||
240 | put_pid(old_pid); | ||
241 | } | ||
242 | } | ||
243 | |||
244 | return mm; | ||
245 | } | ||
246 | |||
247 | |||
248 | |||
169 | void cxl_handle_fault(struct work_struct *fault_work) | 249 | void cxl_handle_fault(struct work_struct *fault_work) |
170 | { | 250 | { |
171 | struct cxl_context *ctx = | 251 | struct cxl_context *ctx = |
172 | container_of(fault_work, struct cxl_context, fault_work); | 252 | container_of(fault_work, struct cxl_context, fault_work); |
173 | u64 dsisr = ctx->dsisr; | 253 | u64 dsisr = ctx->dsisr; |
174 | u64 dar = ctx->dar; | 254 | u64 dar = ctx->dar; |
175 | struct task_struct *task = NULL; | ||
176 | struct mm_struct *mm = NULL; | 255 | struct mm_struct *mm = NULL; |
177 | 256 | ||
178 | if (cxl_p2n_read(ctx->afu, CXL_PSL_DSISR_An) != dsisr || | 257 | if (cxl_p2n_read(ctx->afu, CXL_PSL_DSISR_An) != dsisr || |
@@ -195,17 +274,17 @@ void cxl_handle_fault(struct work_struct *fault_work)
195 | "DSISR: %#llx DAR: %#llx\n", ctx->pe, dsisr, dar); | 274 | "DSISR: %#llx DAR: %#llx\n", ctx->pe, dsisr, dar); |
196 | 275 | ||
197 | if (!ctx->kernel) { | 276 | if (!ctx->kernel) { |
198 | if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) { | 277 | |
199 | pr_devel("cxl_handle_fault unable to get task %i\n", | 278 | mm = get_mem_context(ctx); |
200 | pid_nr(ctx->pid)); | 279 | /* indicates all the thread in task group have exited */ |
280 | if (mm == NULL) { | ||
281 | pr_devel("%s: unable to get mm for pe=%d pid=%i\n", | ||
282 | __func__, ctx->pe, pid_nr(ctx->pid)); | ||
201 | cxl_ack_ae(ctx); | 283 | cxl_ack_ae(ctx); |
202 | return; | 284 | return; |
203 | } | 285 | } else { |
204 | if (!(mm = get_task_mm(task))) { | 286 | pr_devel("Handling page fault for pe=%d pid=%i\n", |
205 | pr_devel("cxl_handle_fault unable to get mm %i\n", | 287 | ctx->pe, pid_nr(ctx->pid)); |
206 | pid_nr(ctx->pid)); | ||
207 | cxl_ack_ae(ctx); | ||
208 | goto out; | ||
209 | } | 288 | } |
210 | } | 289 | } |
211 | 290 | ||
@@ -218,33 +297,22 @@ void cxl_handle_fault(struct work_struct *fault_work)
218 | 297 | ||
219 | if (mm) | 298 | if (mm) |
220 | mmput(mm); | 299 | mmput(mm); |
221 | out: | ||
222 | if (task) | ||
223 | put_task_struct(task); | ||
224 | } | 300 | } |
225 | 301 | ||
226 | static void cxl_prefault_one(struct cxl_context *ctx, u64 ea) | 302 | static void cxl_prefault_one(struct cxl_context *ctx, u64 ea) |
227 | { | 303 | { |
228 | int rc; | ||
229 | struct task_struct *task; | ||
230 | struct mm_struct *mm; | 304 | struct mm_struct *mm; |
231 | 305 | ||
232 | if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) { | 306 | mm = get_mem_context(ctx); |
233 | pr_devel("cxl_prefault_one unable to get task %i\n", | 307 | if (mm == NULL) { |
234 | pid_nr(ctx->pid)); | ||
235 | return; | ||
236 | } | ||
237 | if (!(mm = get_task_mm(task))) { | ||
238 | pr_devel("cxl_prefault_one unable to get mm %i\n", | 308 | pr_devel("cxl_prefault_one unable to get mm %i\n", |
239 | pid_nr(ctx->pid)); | 309 | pid_nr(ctx->pid)); |
240 | put_task_struct(task); | ||
241 | return; | 310 | return; |
242 | } | 311 | } |
243 | 312 | ||
244 | rc = cxl_fault_segment(ctx, mm, ea); | 313 | cxl_fault_segment(ctx, mm, ea); |
245 | 314 | ||
246 | mmput(mm); | 315 | mmput(mm); |
247 | put_task_struct(task); | ||
248 | } | 316 | } |
249 | 317 | ||
250 | static u64 next_segment(u64 ea, u64 vsid) | 318 | static u64 next_segment(u64 ea, u64 vsid) |
@@ -263,18 +331,13 @@ static void cxl_prefault_vma(struct cxl_context *ctx)
263 | struct copro_slb slb; | 331 | struct copro_slb slb; |
264 | struct vm_area_struct *vma; | 332 | struct vm_area_struct *vma; |
265 | int rc; | 333 | int rc; |
266 | struct task_struct *task; | ||
267 | struct mm_struct *mm; | 334 | struct mm_struct *mm; |
268 | 335 | ||
269 | if (!(task = get_pid_task(ctx->pid, PIDTYPE_PID))) { | 336 | mm = get_mem_context(ctx); |
270 | pr_devel("cxl_prefault_vma unable to get task %i\n", | 337 | if (mm == NULL) { |
271 | pid_nr(ctx->pid)); | ||
272 | return; | ||
273 | } | ||
274 | if (!(mm = get_task_mm(task))) { | ||
275 | pr_devel("cxl_prefault_vm unable to get mm %i\n", | 338 | pr_devel("cxl_prefault_vm unable to get mm %i\n", |
276 | pid_nr(ctx->pid)); | 339 | pid_nr(ctx->pid)); |
277 | goto out1; | 340 | return; |
278 | } | 341 | } |
279 | 342 | ||
280 | down_read(&mm->mmap_sem); | 343 | down_read(&mm->mmap_sem); |
@@ -295,8 +358,6 @@ static void cxl_prefault_vma(struct cxl_context *ctx)
295 | up_read(&mm->mmap_sem); | 358 | up_read(&mm->mmap_sem); |
296 | 359 | ||
297 | mmput(mm); | 360 | mmput(mm); |
298 | out1: | ||
299 | put_task_struct(task); | ||
300 | } | 361 | } |
301 | 362 | ||
302 | void cxl_prefault(struct cxl_context *ctx, u64 wed) | 363 | void cxl_prefault(struct cxl_context *ctx, u64 wed) |
diff --git a/drivers/misc/cxl/file.c b/drivers/misc/cxl/file.c
index 5cc14599837d..783337d22f36 100644
--- a/drivers/misc/cxl/file.c
+++ b/drivers/misc/cxl/file.c
@@ -201,8 +201,12 @@ static long afu_ioctl_start_work(struct cxl_context *ctx,
201 | * where a process (master, some daemon, etc) has opened the chardev on | 201 | * where a process (master, some daemon, etc) has opened the chardev on |
202 | * behalf of another process, so the AFU's mm gets bound to the process | 202 | * behalf of another process, so the AFU's mm gets bound to the process |
203 | * that performs this ioctl and not the process that opened the file. | 203 | * that performs this ioctl and not the process that opened the file. |
204 | * Also we grab the PID of the group leader so that if the task that | ||
205 | * has performed the attach operation exits the mm context of the | ||
206 | * process is still accessible. | ||
204 | */ | 207 | */ |
205 | ctx->pid = get_pid(get_task_pid(current, PIDTYPE_PID)); | 208 | ctx->pid = get_task_pid(current, PIDTYPE_PID); |
209 | ctx->glpid = get_task_pid(current->group_leader, PIDTYPE_PID); | ||
206 | 210 | ||
207 | trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr); | 211 | trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr); |
208 | 212 | ||