Skip to content

Commit b566912

Browse files
authoredJan 19, 2023
fix: allow custom async matchers (#2707)
1 parent 8560758 commit b566912

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed
 

‎packages/expect/src/jest-extend.ts

+11-17
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ import {
1717
subsetEquality,
1818
} from './jest-utils'
1919

20-
const isAsyncFunction = (fn: unknown) =>
21-
typeof fn === 'function' && (fn as any)[Symbol.toStringTag] === 'AsyncFunction'
22-
2320
const getMatcherState = (assertion: Chai.AssertionStatic & Chai.Assertion, expect: Vi.ExpectStatic) => {
2421
const obj = assertion._obj
2522
const isNot = util.flag(assertion, 'negate') as boolean
@@ -56,30 +53,27 @@ class JestExtendError extends Error {
5653
function JestExtendPlugin(expect: Vi.ExpectStatic, matchers: MatchersObject): ChaiPlugin {
5754
return (c, utils) => {
5855
Object.entries(matchers).forEach(([expectAssertionName, expectAssertion]) => {
59-
function expectSyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
56+
function expectWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
6057
const { state, isNot, obj } = getMatcherState(this, expect)
6158

6259
// @ts-expect-error args wanting tuple
63-
const { pass, message, actual, expected } = expectAssertion.call(state, obj, ...args) as SyncExpectationResult
64-
65-
if ((pass && isNot) || (!pass && !isNot))
66-
throw new JestExtendError(message(), actual, expected)
67-
}
60+
const result = expectAssertion.call(state, obj, ...args)
6861

69-
async function expectAsyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
70-
const { state, isNot, obj } = getMatcherState(this, expect)
62+
if (result && typeof result === 'object' && result instanceof Promise) {
63+
return result.then(({ pass, message, actual, expected }) => {
64+
if ((pass && isNot) || (!pass && !isNot))
65+
throw new JestExtendError(message(), actual, expected)
66+
})
67+
}
7168

72-
// @ts-expect-error args wanting tuple
73-
const { pass, message, actual, expected } = await expectAssertion.call(state, obj, ...args) as SyncExpectationResult
69+
const { pass, message, actual, expected } = result
7470

7571
if ((pass && isNot) || (!pass && !isNot))
7672
throw new JestExtendError(message(), actual, expected)
7773
}
7874

79-
const expectAssertionWrapper = isAsyncFunction(expectAssertion) ? expectAsyncWrapper : expectSyncWrapper
80-
81-
utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, expectAssertionWrapper)
82-
utils.addMethod(c.Assertion.prototype, expectAssertionName, expectAssertionWrapper)
75+
utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, expectWrapper)
76+
utils.addMethod(c.Assertion.prototype, expectAssertionName, expectWrapper)
8377

8478
class CustomMatcher extends AsymmetricMatcher<[unknown, ...unknown[]]> {
8579
constructor(inverse = false, ...sample: [unknown, ...unknown[]]) {

‎test/core/test/jest-expect.test.ts

+27-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ class TestError extends Error {}
88
// For expect.extend
99
interface CustomMatchers<R = unknown> {
1010
toBeDividedBy(divisor: number): R
11+
toBeTestedAsync(): Promise<R>
12+
toBeTestedSync(): R
13+
toBeTestedPromise(): R
1114
}
1215
declare global {
1316
namespace Vi {
@@ -142,7 +145,7 @@ describe('jest-expect', () => {
142145
expect(['Bob', 'Eve']).toEqual(expect.not.arrayContaining(['Steve']))
143146
})
144147

145-
it('expect.extend', () => {
148+
it('expect.extend', async () => {
146149
expect.extend({
147150
toBeDividedBy(received, divisor) {
148151
const pass = received % divisor === 0
@@ -161,6 +164,24 @@ describe('jest-expect', () => {
161164
}
162165
}
163166
},
167+
async toBeTestedAsync() {
168+
return {
169+
pass: false,
170+
message: () => 'toBeTestedAsync',
171+
}
172+
},
173+
toBeTestedSync() {
174+
return {
175+
pass: false,
176+
message: () => 'toBeTestedSync',
177+
}
178+
},
179+
toBeTestedPromise() {
180+
return Promise.resolve({
181+
pass: false,
182+
message: () => 'toBeTestedPromise',
183+
})
184+
},
164185
})
165186

166187
expect(5).toBeDividedBy(5)
@@ -169,6 +190,11 @@ describe('jest-expect', () => {
169190
one: expect.toBeDividedBy(1),
170191
two: expect.not.toBeDividedBy(5),
171192
})
193+
expect(() => expect(2).toBeDividedBy(5)).toThrowError()
194+
195+
expect(() => expect(null).toBeTestedSync()).toThrowError('toBeTestedSync')
196+
await expect(async () => await expect(null).toBeTestedAsync()).rejects.toThrowError('toBeTestedAsync')
197+
await expect(async () => await expect(null).toBeTestedPromise()).rejects.toThrowError('toBeTestedPromise')
172198
})
173199

174200
it('object', () => {

0 commit comments

Comments
 (0)
Please sign in to comment.