guile-knots/knots/parallelism.scm
Christopher Baines 7ba77010ae Handle %stacks not being a pair
Not sure when this would happen, but guard against it.
2025-05-15 09:26:29 +01:00

289 lines
9.8 KiB
Scheme

;;; Guile Knots
;;; Copyright © 2020 Christopher Baines <mail@cbaines.net>
;;;
;;; This file is part of Guile Knots.
;;;
;;; Guile Knots is free software; you can redistribute it and/or
;;; modify it under the terms of the GNU General Public License as
;;; published by the Free Software Foundation; either version 3 of the
;;; License, or (at your option) any later version.
;;;
;;; Guile Knots is distributed in the hope that it will be useful,
;;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
;;; General Public License for more details.
;;;
;;; You should have received a copy of the GNU General Public License
;;; along with Guile Knots. If not, see
;;; <http://www.gnu.org/licenses/>.
(define-module (knots parallelism)
#:use-module (srfi srfi-1)
#:use-module (srfi srfi-71)
#:use-module (ice-9 match)
#:use-module (ice-9 control)
#:use-module (ice-9 exceptions)
#:use-module (fibers)
#:use-module (fibers channels)
#:use-module (fibers operations)
#:use-module (knots)
#:export (fibers-batch-map
fibers-map
fibers-map-with-progress
fibers-batch-for-each
fibers-for-each
fibers-parallel
fibers-let
fiberize))
(define (defer-to-parallel-fiber thunk)
  "Run THUNK in a newly spawned fiber (@code{#:parallel? #t}) and return
a channel on which a single reply is put when THUNK finishes.

The reply is the list of values THUNK returned, or @code{(exception EXN)}
if THUNK raised EXN.  When EXN is an exception object it is first combined
with a knots exception holding the stack captured at the point of the
raise; any other raised value is passed through unchanged."
  (let ((reply (make-channel)))
    (spawn-fiber
     (lambda ()
       (with-exception-handler
           (lambda (exn)
             ;; Invoked after unwinding (#:unwind? #t below).
             (put-message reply (list 'exception exn)))
         (lambda ()
           (with-exception-handler
               (lambda (exn)
                 ;; Invoked without unwinding, so the stack at the raise
                 ;; point can still be captured.
                 (if (exception? exn)
                     (let ((stack
                            (match (fluid-ref %stacks)
                              ((_ . prompt-tag)
                               (make-stack #t
                                           0 prompt-tag
                                           0 (and prompt-tag 1)))
                              (_
                               (make-stack #t)))))
                       (raise-exception
                        (make-exception
                         exn
                         (make-knots-exception stack))))
                     ;; make-exception expects exception objects, so a raw
                     ;; value such as the one from (raise 'oops) would turn
                     ;; into an unrelated wrong-type error; re-raise it as is.
                     (raise-exception exn)))
             (lambda ()
               (call-with-values
                   (lambda ()
                     (start-stack #t (thunk)))
                 (lambda vals
                   (put-message reply vals))))))
         #:unwind? #t))
     #:parallel? #t)
    reply))
(define (fetch-result-of-defered-thunks . reply-channels)
  "Wait for the reply on each of REPLY-CHANNELS and return a list with one
element per channel, in the same order.

A reply of the form @code{(exception EXN)} causes EXN to be raised; any
other reply is treated as a list of values and passed to @code{values}.
NOTE(review): each element is produced in a single-value context, so only
the first of several values is presumably kept — confirm."
  (define (reply->value reply)
    (match reply
      (('exception exn)
       (raise-exception exn))
      (vals
       (apply values vals))))
  ;; Receive every reply before examining any of them, so that all senders
  ;; are released even when one of the replies is an exception.
  (let ((replies (map get-message reply-channels)))
    (map reply->value replies)))
(define (fibers-batch-map proc parallelism-limit . lists)
  "Like SRFI-1 @code{map}, apply PROC to corresponding elements of LISTS,
but run each application in its own parallel fiber, with at most
PARALLELISM-LIMIT applications in flight at once.

Each of LISTS may be a list or a vector.  The result is a vector if the
first of LISTS is a vector and a list otherwise; only the first value
returned by each application is kept.  Processing stops at the end of the
shortest input.  If an application raises an exception, it is re-raised
here when its reply is received; fibers still in flight are left running.
Signals an error if PARALLELISM-LIMIT is less than 1."
  (define vecs
    (map (lambda (list-or-vec)
           (if (vector? list-or-vec)
               list-or-vec
               (list->vector list-or-vec)))
         lists))

  ;; Stop at the shortest input, as SRFI-1 map does, instead of failing
  ;; with an out-of-range vector-ref when the inputs differ in length.
  (define vecs-length
    (apply min (map vector-length vecs)))

  ;; Slot I first holds the reply channel of the fiber working on index I,
  ;; then that fiber's result once its reply has been received.
  (define result-vec
    (make-vector vecs-length))

  ;; Start a fiber applying PROC to the elements at INDEX and park its
  ;; reply channel in RESULT-VEC.
  (define (start! index)
    (vector-set! result-vec
                 index
                 (defer-to-parallel-fiber
                   (lambda ()
                     (apply proc
                            (map (lambda (vec) (vector-ref vec index))
                                 vecs))))))

  ;; Wait for whichever in-flight fiber replies first, store its result
  ;; (or re-raise its exception) and return the remaining in-flight
  ;; indexes.
  (define (await-one in-flight)
    (perform-operation
     (apply choice-operation
            (map (lambda (index)
                   (wrap-operation
                    (get-operation (vector-ref result-vec index))
                    (lambda (reply)
                      (match reply
                        (('exception exn)
                         (raise-exception exn))
                        (vals
                         (vector-set! result-vec index (first vals))
                         (delete index in-flight))))))
                 in-flight))))

  ;; A limit below 1 would otherwise wait on an empty choice forever.
  (unless (>= parallelism-limit 1)
    (error "fibers-batch-map: parallelism-limit must be at least 1"
           parallelism-limit))

  (let loop ((next-index 0)
             (in-flight '())
             (in-flight-count 0))
    (cond
     ;; Everything started and every reply received.
     ((and (= next-index vecs-length)
           (null? in-flight))
      (if (vector? (first lists))
          result-vec
          (vector->list result-vec)))
     ;; Nothing left to start, or at the limit: wait for a reply.
     ((or (= next-index vecs-length)
          (>= in-flight-count parallelism-limit))
      (loop next-index
            (await-one in-flight)
            (- in-flight-count 1)))
     (else
      (start! next-index)
      (loop (+ next-index 1)
            (cons next-index in-flight)
            (+ in-flight-count 1))))))
(define (fibers-map proc . lists)
  "Apply PROC to corresponding elements of LISTS in parallel fibers, as
@code{fibers-batch-map} does, with a parallelism limit of 20."
  (let ((default-parallelism-limit 20))
    (apply fibers-batch-map proc default-parallelism-limit lists)))
(define (fibers-batch-for-each proc parallelism-limit . lists)
  "Apply PROC to corresponding elements of LISTS for its side effects,
using @code{fibers-batch-map} with PARALLELISM-LIMIT.  The return value
is unspecified."
  ;; Wrap PROC so that every application returns exactly one (unspecified)
  ;; value, whatever PROC itself returns.
  (define (run-for-effect . args)
    (apply proc args)
    *unspecified*)
  (apply fibers-batch-map run-for-effect parallelism-limit lists)
  *unspecified*)
(define (fibers-for-each proc . lists)
  "Apply PROC to corresponding elements of LISTS for its side effects, as
@code{fibers-batch-for-each} does, with a parallelism limit of 20."
  (let ((default-parallelism-limit 20))
    (apply fibers-batch-for-each proc default-parallelism-limit lists)))
;; (fibers-parallel E ...) evaluates each expression E in its own parallel
;; fiber and returns the results as multiple values, one per expression,
;; in order.  Every fiber is started before any reply is awaited, and every
;; reply is received before an exception raised by one of the expressions
;; is re-raised.
(define-syntax-rule (fibers-parallel e0 ...)
  (apply values
         (fetch-result-of-defered-thunks
          (defer-to-parallel-fiber (lambda () e0))
          ...)))
;; (fibers-let ((VAR EXPR) ...) BODY ...) evaluates every EXPR concurrently
;; via fibers-parallel, binds each result to the corresponding VAR, and
;; then evaluates BODY.
(define-syntax fibers-let
  (syntax-rules ()
    ((_ ((var expr) ...) body0 body ...)
     (call-with-values
         (lambda () (fibers-parallel expr ...))
       (lambda (var ...)
         body0 body ...)))))
(define* (fibers-map-with-progress proc lists #:key report)
  "Apply PROC to corresponding elements of LISTS, a list of lists (as in
@code{(apply map PROC LISTS)}), running every application at once in its
own parallel fiber.

When REPORT is given, it is called with one entry per element, each of
the form @code{(STATUS ARG ...)}: STATUS is @code{#f} while that
application's reply has not been received, @code{(result VALUE ...)} once
it has returned and @code{(exception EXN)} once it has raised EXN.  REPORT
is called once before any reply is received and again after each reply.

Once every reply has been received: if any application raised, the
exception from the first such element (in input order) is re-raised;
otherwise a list is returned holding, for each element, the list of all
values PROC returned for it."
  ;; Turn a reply from defer-to-parallel-fiber into a status: an
  ;; (exception EXN) reply is kept as is, a list of values is tagged.
  (define (reply->status reply)
    (match reply
      (('exception exn) reply)
      (vals (cons 'result vals))))

  ;; An operation that waits for CHANNEL's reply and returns ENTRIES with
  ;; that channel's entry replaced by (#f . STATUS).
  (define (reply-operation channel entries)
    (wrap-operation
     (get-operation channel)
     (lambda (reply)
       (map (match-lambda
              ((c . status)
               (if (eq? c channel)
                   (cons #f (reply->status reply))
                   (cons c status))))
            entries))))

  ;; All replies are in: raise the first failure in input order, if any,
  ;; otherwise return the value lists.
  (define (finish entries)
    (for-each (match-lambda
                ((#f 'exception exn)
                 (raise-exception exn))
                (_ #t))
              entries)
    (map (match-lambda
           ((#f 'result . vals) vals))
         entries))

  ;; Each entry is (CHANNEL . #f) while its reply is outstanding, and
  ;; (#f . STATUS) once the reply has been received.
  (let loop ((entries
              (apply map
                     (lambda args
                       (cons (defer-to-parallel-fiber
                               (lambda ()
                                 (apply proc args)))
                             #f))
                     lists)))
    (let ((active-channels (filter-map car entries)))
      (when report
        (report (apply map list (map cdr entries) lists)))
      (if (null? active-channels)
          (finish entries)
          (loop
           (perform-operation
            (apply choice-operation
                   (map (lambda (channel)
                          (reply-operation channel entries))
                        active-channels))))))))
(define* (fiberize proc
                   #:key (parallelism 1)
                   (input-channel (make-channel))
                   (process-channel input-channel))
  "Return a procedure that forwards its arguments to PROC, which is run by
one of PARALLELISM worker fibers (each spawned with @code{#:parallel? #t}).

The returned procedure puts @code{(REPLY-CHANNEL . ARGS)} on
INPUT-CHANNEL and waits for the reply; the workers take requests from
PROCESS-CHANNEL, which defaults to INPUT-CHANNEL.  It returns the values
PROC returned, or re-raises what PROC raised.  Exception objects are first
combined with a knots exception holding the stack captured at the point
of the raise; other raised values are passed through unchanged.  The
workers loop forever."
  ;; Apply PROC to ARGS, returning (result VALUE ...) or (exception EXN).
  (define (apply-proc/capture args)
    (with-exception-handler
        (lambda (exn)
          ;; Invoked after unwinding (#:unwind? #t below).
          (list 'exception exn))
      (lambda ()
        (with-exception-handler
            (lambda (exn)
              ;; Invoked without unwinding, so the stack at the raise point
              ;; can still be captured.
              (if (exception? exn)
                  (let ((stack
                         (match (fluid-ref %stacks)
                           ((_ . prompt-tag)
                            (make-stack #t
                                        0 prompt-tag
                                        0 (and prompt-tag 1)))
                           (_
                            (make-stack #t)))))
                    (raise-exception
                     (make-exception
                      exn
                      (make-knots-exception stack))))
                  ;; make-exception expects exception objects; re-raise
                  ;; other values (e.g. from (raise 'oops)) unchanged.
                  (raise-exception exn)))
          (lambda ()
            (call-with-values
                (lambda ()
                  (start-stack #t (apply proc args)))
              (lambda vals
                (cons 'result vals))))))
      #:unwind? #t))

  ;; Serve requests from PROCESS-CHANNEL forever.
  (define (worker)
    (while #t
      (let ((reply-channel args (car+cdr (get-message process-channel))))
        (put-message reply-channel (apply-proc/capture args)))))

  (for-each (lambda _
              (spawn-fiber worker #:parallel? #t))
            (iota parallelism))

  (lambda args
    (let ((reply-channel (make-channel)))
      (put-message input-channel (cons reply-channel args))
      (match (get-message reply-channel)
        (('result . vals) (apply values vals))
        (('exception exn)
         (raise-exception exn))))))