@@ -17,9 +17,6 @@ import {
17
17
subsetEquality ,
18
18
} from './jest-utils'
19
19
20
- const isAsyncFunction = ( fn : unknown ) =>
21
- typeof fn === 'function' && ( fn as any ) [ Symbol . toStringTag ] === 'AsyncFunction'
22
-
23
20
const getMatcherState = ( assertion : Chai . AssertionStatic & Chai . Assertion , expect : Vi . ExpectStatic ) => {
24
21
const obj = assertion . _obj
25
22
const isNot = util . flag ( assertion , 'negate' ) as boolean
@@ -56,30 +53,27 @@ class JestExtendError extends Error {
56
53
function JestExtendPlugin ( expect : Vi . ExpectStatic , matchers : MatchersObject ) : ChaiPlugin {
57
54
return ( c , utils ) => {
58
55
Object . entries ( matchers ) . forEach ( ( [ expectAssertionName , expectAssertion ] ) => {
59
- function expectSyncWrapper ( this : Chai . AssertionStatic & Chai . Assertion , ...args : any [ ] ) {
56
+ function expectWrapper ( this : Chai . AssertionStatic & Chai . Assertion , ...args : any [ ] ) {
60
57
const { state, isNot, obj } = getMatcherState ( this , expect )
61
58
62
59
// @ts -expect-error args wanting tuple
63
- const { pass, message, actual, expected } = expectAssertion . call ( state , obj , ...args ) as SyncExpectationResult
64
-
65
- if ( ( pass && isNot ) || ( ! pass && ! isNot ) )
66
- throw new JestExtendError ( message ( ) , actual , expected )
67
- }
60
+ const result = expectAssertion . call ( state , obj , ...args )
68
61
69
- async function expectAsyncWrapper ( this : Chai . AssertionStatic & Chai . Assertion , ...args : any [ ] ) {
70
- const { state, isNot, obj } = getMatcherState ( this , expect )
62
+ if ( result && typeof result === 'object' && result instanceof Promise ) {
63
+ return result . then ( ( { pass, message, actual, expected } ) => {
64
+ if ( ( pass && isNot ) || ( ! pass && ! isNot ) )
65
+ throw new JestExtendError ( message ( ) , actual , expected )
66
+ } )
67
+ }
71
68
72
- // @ts -expect-error args wanting tuple
73
- const { pass, message, actual, expected } = await expectAssertion . call ( state , obj , ...args ) as SyncExpectationResult
69
+ const { pass, message, actual, expected } = result
74
70
75
71
if ( ( pass && isNot ) || ( ! pass && ! isNot ) )
76
72
throw new JestExtendError ( message ( ) , actual , expected )
77
73
}
78
74
79
- const expectAssertionWrapper = isAsyncFunction ( expectAssertion ) ? expectAsyncWrapper : expectSyncWrapper
80
-
81
- utils . addMethod ( ( globalThis as any ) [ JEST_MATCHERS_OBJECT ] . matchers , expectAssertionName , expectAssertionWrapper )
82
- utils . addMethod ( c . Assertion . prototype , expectAssertionName , expectAssertionWrapper )
75
+ utils . addMethod ( ( globalThis as any ) [ JEST_MATCHERS_OBJECT ] . matchers , expectAssertionName , expectWrapper )
76
+ utils . addMethod ( c . Assertion . prototype , expectAssertionName , expectWrapper )
83
77
84
78
class CustomMatcher extends AsymmetricMatcher < [ unknown , ...unknown [ ] ] > {
85
79
constructor ( inverse = false , ...sample : [ unknown , ...unknown [ ] ] ) {
0 commit comments