Skip to content

Commit

Permalink
fix: allow custom async matchers (#2707)
Browse files Browse the repository at this point in the history
  • Loading branch information
sheremet-va committed Jan 19, 2023
1 parent 8560758 commit b566912
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
28 changes: 11 additions & 17 deletions packages/expect/src/jest-extend.ts
Expand Up @@ -17,9 +17,6 @@ import {
subsetEquality,
} from './jest-utils'

const isAsyncFunction = (fn: unknown) =>
typeof fn === 'function' && (fn as any)[Symbol.toStringTag] === 'AsyncFunction'

const getMatcherState = (assertion: Chai.AssertionStatic & Chai.Assertion, expect: Vi.ExpectStatic) => {
const obj = assertion._obj
const isNot = util.flag(assertion, 'negate') as boolean
Expand Down Expand Up @@ -56,30 +53,27 @@ class JestExtendError extends Error {
function JestExtendPlugin(expect: Vi.ExpectStatic, matchers: MatchersObject): ChaiPlugin {
return (c, utils) => {
Object.entries(matchers).forEach(([expectAssertionName, expectAssertion]) => {
function expectSyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
function expectWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
const { state, isNot, obj } = getMatcherState(this, expect)

// @ts-expect-error args wanting tuple
const { pass, message, actual, expected } = expectAssertion.call(state, obj, ...args) as SyncExpectationResult

if ((pass && isNot) || (!pass && !isNot))
throw new JestExtendError(message(), actual, expected)
}
const result = expectAssertion.call(state, obj, ...args)

async function expectAsyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) {
const { state, isNot, obj } = getMatcherState(this, expect)
if (result && typeof result === 'object' && result instanceof Promise) {
return result.then(({ pass, message, actual, expected }) => {
if ((pass && isNot) || (!pass && !isNot))
throw new JestExtendError(message(), actual, expected)
})
}

// @ts-expect-error args wanting tuple
const { pass, message, actual, expected } = await expectAssertion.call(state, obj, ...args) as SyncExpectationResult
const { pass, message, actual, expected } = result

if ((pass && isNot) || (!pass && !isNot))
throw new JestExtendError(message(), actual, expected)
}

const expectAssertionWrapper = isAsyncFunction(expectAssertion) ? expectAsyncWrapper : expectSyncWrapper

utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, expectAssertionWrapper)
utils.addMethod(c.Assertion.prototype, expectAssertionName, expectAssertionWrapper)
utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, expectWrapper)
utils.addMethod(c.Assertion.prototype, expectAssertionName, expectWrapper)

class CustomMatcher extends AsymmetricMatcher<[unknown, ...unknown[]]> {
constructor(inverse = false, ...sample: [unknown, ...unknown[]]) {
Expand Down
28 changes: 27 additions & 1 deletion test/core/test/jest-expect.test.ts
Expand Up @@ -8,6 +8,9 @@ class TestError extends Error {}
// For expect.extend
interface CustomMatchers<R = unknown> {
toBeDividedBy(divisor: number): R
toBeTestedAsync(): Promise<R>
toBeTestedSync(): R
toBeTestedPromise(): R
}
declare global {
namespace Vi {
Expand Down Expand Up @@ -142,7 +145,7 @@ describe('jest-expect', () => {
expect(['Bob', 'Eve']).toEqual(expect.not.arrayContaining(['Steve']))
})

it('expect.extend', () => {
it('expect.extend', async () => {
expect.extend({
toBeDividedBy(received, divisor) {
const pass = received % divisor === 0
Expand All @@ -161,6 +164,24 @@ describe('jest-expect', () => {
}
}
},
async toBeTestedAsync() {
return {
pass: false,
message: () => 'toBeTestedAsync',
}
},
toBeTestedSync() {
return {
pass: false,
message: () => 'toBeTestedSync',
}
},
toBeTestedPromise() {
return Promise.resolve({
pass: false,
message: () => 'toBeTestedPromise',
})
},
})

expect(5).toBeDividedBy(5)
Expand All @@ -169,6 +190,11 @@ describe('jest-expect', () => {
one: expect.toBeDividedBy(1),
two: expect.not.toBeDividedBy(5),
})
expect(() => expect(2).toBeDividedBy(5)).toThrowError()

expect(() => expect(null).toBeTestedSync()).toThrowError('toBeTestedSync')
await expect(async () => await expect(null).toBeTestedAsync()).rejects.toThrowError('toBeTestedAsync')
await expect(async () => await expect(null).toBeTestedPromise()).rejects.toThrowError('toBeTestedPromise')
})

it('object', () => {
Expand Down

0 comments on commit b566912

Please sign in to comment.