diff --git a/knots/parallelism.scm b/knots/parallelism.scm index b5b03d5..9331757 100644 --- a/knots/parallelism.scm +++ b/knots/parallelism.scm @@ -35,25 +35,6 @@ fibers-parallel fibers-let)) -;; Like split-at, but don't care about the order of the resulting lists, and -;; don't error if the list is shorter than i elements -(define (split-at* lst i) - (let lp ((l lst) (n i) (acc '())) - (if (or (<= n 0) (null? l)) - (values (reverse! acc) l) - (lp (cdr l) (- n 1) (cons (car l) acc))))) - -;; As this can be called with lists with tens of thousands of items in them, -;; batch the -(define (get-batch batch-size lists) - (let ((split-lists - (map (lambda (lst) - (let ((batch rest (split-at* lst batch-size))) - (cons batch rest))) - lists))) - (values (map car split-lists) - (map cdr split-lists)))) - (define (defer-to-parallel-fiber thunk) (let ((reply (make-channel))) (spawn-fiber @@ -85,46 +66,86 @@ (apply values result))) responses))) -(define (fibers-batch-map proc batch-size . lists) - (let loop ((lists lists) - (result '())) - (let ((batch - rest - (get-batch batch-size lists))) - (if (any null? batch) - result - (let ((response-channels - (apply map - (lambda args - (defer-to-parallel-fiber - (lambda () - (apply proc args)))) - batch))) - (loop rest - (append! result - (apply fetch-result-of-defered-thunks - response-channels)))))))) +(define (fibers-batch-map proc parallelism-limit . lists) + (define vecs (map (lambda (list-or-vec) + (if (vector? list-or-vec) + list-or-vec + (list->vector list-or-vec))) + lists)) + + (define vecs-length + (vector-length (first vecs))) + + (define result-vec + (make-vector vecs-length)) + + (let loop ((next-to-process-index 0) + (channel-indexes '())) + (if (and (eq? #f next-to-process-index) + (null? channel-indexes)) + (if (vector? (first lists)) + result-vec + (vector->list result-vec)) + + (if (or (= (length channel-indexes) + (min parallelism-limit vecs-length)) + (eq? #f next-to-process-index)) + (let ((new-index + new-channel-indexes + (perform-operation + (apply + choice-operation + (map + (lambda (index) + (wrap-operation + (get-operation + (vector-ref result-vec index)) + (lambda (result) + (match result + (('exception . exn) + (raise-exception exn)) + (_ + (vector-set! result-vec + index + (first result)) + + (values next-to-process-index + (lset-difference = + channel-indexes + (list index)))))))) + channel-indexes))))) + (loop new-index + new-channel-indexes)) + + (loop (if (= (+ 1 next-to-process-index) + vecs-length) + #f + (+ 1 next-to-process-index)) + (begin + (vector-set! + result-vec + next-to-process-index + (defer-to-parallel-fiber + (lambda () + (apply proc + (map (lambda (vec) + (vector-ref vec next-to-process-index)) + vecs))))) + (cons next-to-process-index + channel-indexes))))))) (define (fibers-map proc . lists) (apply fibers-batch-map proc 20 lists)) -(define (fibers-batch-for-each proc batch-size . lists) - (let loop ((lists lists)) - (let ((batch - rest - (get-batch batch-size lists))) - (if (any null? batch) - *unspecified* - (let ((response-channels - (apply map - (lambda args - (defer-to-parallel-fiber - (lambda () - (apply proc args)))) - batch))) - (apply fetch-result-of-defered-thunks - response-channels) - (loop rest)))))) +(define (fibers-batch-for-each proc parallelism-limit . lists) + (apply fibers-batch-map + (lambda args + (apply proc args) + *unspecified*) + parallelism-limit + lists) + + *unspecified*) (define (fibers-for-each proc . lists) (apply fibers-batch-for-each proc 20 lists)) diff --git a/tests/parallelism.scm b/tests/parallelism.scm index 8249d68..0901b3c 100644 --- a/tests/parallelism.scm +++ b/tests/parallelism.scm @@ -3,6 +3,7 @@ (unit-test) (knots parallelism)) +;; Test fibers-map (run-fibers-for-tests (lambda () (assert-equal @@ -12,4 +13,34 @@ (* 2 i)) (iota 34)))))) +;; Test fibers-batch-map with a large batch size +(run-fibers-for-tests + (lambda () + (assert-equal + 1122 + (apply + (fibers-batch-map + (lambda (i) + (* 2 i)) + 100 + (iota 34)))))) + +;; Test fibers-map with vectors +(run-fibers-for-tests + (lambda () + (assert-equal + 1122 + (apply + (vector->list + (fibers-map + (lambda (i) + (* 2 i)) + (list->vector (iota 34)))))))) + +;; Test fibers-for-each +(run-fibers-for-tests + (lambda () + (fibers-for-each + (lambda (i) + (* 2 i)) + (iota 34)))) + (display "parallelism test finished successfully\n")