Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vitest): support vi.waitFor method #4113

Merged
merged 18 commits into from Sep 14, 2023
Merged
63 changes: 60 additions & 3 deletions docs/api/vi.md
Expand Up @@ -254,10 +254,10 @@ import { vi } from 'vitest'
```ts
// increment.test.js
import { vi } from 'vitest'

// axios is a default export from `__mocks__/axios.js`
import axios from 'axios'

// increment is a named export from `src/__mocks__/increment.js`
import { increment } from '../increment.js'

Expand Down Expand Up @@ -371,7 +371,7 @@ test('importing the next module imports mocked one', async () => {

```ts
import { vi } from 'vitest'

import { data } from './data.js' // Will not get reevaluated beforeEach test

beforeEach(() => {
Expand Down Expand Up @@ -717,3 +717,60 @@ unmockedIncrement(30) === 31
- **Type:** `() => Vitest`

When timers are run out, you may call this method to return mocked timers to its original implementations. All timers that were run before will not be restored.

### vi.waitFor

- **Type:** `function waitFor<T>(callback: WaitForCallback<T>, options?: number | WaitForOptions): Promise<T>`
- **Version**: Since Vitest 0.35.0
Dunqing marked this conversation as resolved.
Show resolved Hide resolved

Wait for the callback to execute successfully. If the callback throws an error or returns a rejected promise it will continue to wait until it succeeds or times out.

This is very useful for testing, see the following example.
Dunqing marked this conversation as resolved.
Show resolved Hide resolved

```ts
import { test, vi } from 'vitest'

let server = false
setTimeout(() => {
server = true
}, 100)

function checkServerStart() {
if (!server)
throw new Error('Server not started')

console.log('Server started')
}

test('Server started successfully', async () => {
const res = await vi.waitFor(checkServerStart, {
Dunqing marked this conversation as resolved.
Show resolved Hide resolved
timeout: 500, // default is 1000
interval: 20, // default is 50
})
expect(server).toBe(true)
})
```

It also works for asynchronous callbacks

```ts
import { test, vi } from 'vitest'

async function startServer() {
return new Promise((resolve) => {
setTimeout(() => {
server = true
resolve('Server started')
}, 100)
})
}

test('Server started successfully', async () => {
const server = await vi.waitFor(startServer, {
timeout: 500, // default is 1000
interval: 20, // default is 50
})
expect(server).toBe('Server started')
})
```

4 changes: 3 additions & 1 deletion packages/vitest/src/integrations/vi.ts
Expand Up @@ -9,6 +9,7 @@ import { resetModules, waitForImportsToResolve } from '../utils/modules'
import { FakeTimers } from './mock/timers'
import type { EnhancedSpy, MaybeMocked, MaybeMockedDeep, MaybePartiallyMocked, MaybePartiallyMockedDeep } from './spy'
import { fn, isMockFunction, spies, spyOn } from './spy'
import { waitFor } from './wait'

interface VitestUtils {
useFakeTimers(config?: FakeTimerInstallOpts): this
Expand All @@ -30,6 +31,7 @@ interface VitestUtils {

spyOn: typeof spyOn
fn: typeof fn
waitFor: typeof waitFor

/**
* Run the factory before imports are evaluated. You can return a value from the factory
Expand Down Expand Up @@ -292,7 +294,7 @@ function createVitest(): VitestUtils {

spyOn,
fn,

waitFor,
hoisted<T>(factory: () => T): T {
assertTypes(factory, '"vi.hoisted" factory', ['function'])
return factory()
Expand Down
93 changes: 93 additions & 0 deletions packages/vitest/src/integrations/wait.ts
@@ -0,0 +1,93 @@
import { getSafeTimers } from '@vitest/utils'

// The waitFor function was inspired by https://github.com/testing-library/web-testing-library/pull/2

export type WaitForCallback<T> = () => T | Promise<T>

export interface WaitForOptions {
/**
* @description Time in ms between each check callback
* @default 50ms
*/
interval?: number
/**
* @description Time in ms after which the throw a timeout error
* @default 1000ms
*/
timeout?: number
}

function copyStackTrace(target: Error, source: Error) {
if (source.stack !== undefined)
target.stack = source.stack.replace(source.message, target.message)
return target
}

export function waitFor<T>(callback: WaitForCallback<T>, options: number | WaitForOptions = {}) {
const { setTimeout, setInterval, clearTimeout, clearInterval } = getSafeTimers()
const { interval = 50, timeout = 1000 } = typeof options === 'number' ? { timeout: options } : options
const STACK_TRACE_ERROR = new Error('STACK_TRACE_ERROR')

return new Promise<T>((resolve, reject) => {
let lastError: unknown
let promiseStatus: 'idle' | 'pending' | 'resolved' | 'rejected' = 'idle'
let timeoutId: ReturnType<typeof setTimeout>
let intervalId: ReturnType<typeof setInterval>

const onResolve = (result: T) => {
if (timeoutId)
clearTimeout(timeoutId)
if (intervalId)
clearInterval(intervalId)

resolve(result)
}

const handleTimeout = () => {
let error = lastError
if (!error)
error = copyStackTrace(new Error('Timed out in waitFor!'), STACK_TRACE_ERROR)
sheremet-va marked this conversation as resolved.
Show resolved Hide resolved
Dunqing marked this conversation as resolved.
Show resolved Hide resolved

reject(error)
}

const checkCallback = () => {
if (promiseStatus === 'pending')
return
try {
const result = callback()
if (
result !== null
&& typeof result === 'object'
&& typeof (result as any).then === 'function'
) {
const thenable = result as PromiseLike<T>
promiseStatus = 'pending'
thenable.then(
(resolvedValue) => {
promiseStatus = 'resolved'
onResolve(resolvedValue)
},
(rejectedValue) => {
promiseStatus = 'rejected'
lastError = rejectedValue
},
)
}
else {
onResolve(result as T)
return true
}
}
catch (error) {
lastError = error
}
}

if (checkCallback() === true)
return

timeoutId = setTimeout(handleTimeout, timeout)
intervalId = setInterval(checkCallback, interval)
})
}
107 changes: 107 additions & 0 deletions test/core/test/wait.test.ts
@@ -0,0 +1,107 @@
import { getSafeTimers } from '@vitest/utils'
import { describe, expect, test, vi } from 'vitest'

describe('waitFor', () => {
describe('options', () => {
test('timeout', async () => {
expect(async () => {
await vi.waitFor(() => {
return new Promise((resolve) => {
setTimeout(() => {
resolve(true)
}, 1000)
})
}, 500)
}).rejects.toThrow('Timed out in waitFor!')
})

test('interval', async () => {
const callback = vi.fn(() => {
throw new Error('interval error')
})

await expect(
vi.waitFor(callback, {
timeout: 499,
interval: 100,
}),
).rejects.toMatchInlineSnapshot('[Error: interval error]')
Dunqing marked this conversation as resolved.
Show resolved Hide resolved

expect(callback).toHaveBeenCalledTimes(5)
})
})

test('basic', async () => {
let throwError = false
await vi.waitFor(() => {
if (!throwError) {
throwError = true
throw new Error('basic error')
}
})
expect(throwError).toBe(true)
})

test('async function', async () => {
let finished = false
setTimeout(() => {
finished = true
}, 500)
Dunqing marked this conversation as resolved.
Show resolved Hide resolved
await vi.waitFor(async () => {
if (finished)
return Promise.resolve(true)
else
return Promise.reject(new Error('async function error'))
})
})

test('stacktrace correctly', async () => {
const check = () => {
const _a = 1
// @ts-expect-error test
_a += 1
}
try {
await vi.waitFor(check)
}
catch (error) {
expect((error as Error).message).toMatchInlineSnapshot('"Assignment to constant variable."')
expect((error as Error).stack).toContain('/vitest/test/core/test/wait.test.ts:')
}
})

test('stacktrace point to waitFor', async () => {
const check = async () => {
return new Promise((resolve) => {
setTimeout(resolve, 600)
})
}
try {
await vi.waitFor(check, 500)
}
catch (error) {
expect(error).toMatchInlineSnapshot('[Error: Timed out in waitFor!]')
expect((error as Error).stack?.split('\n')[1]).toMatch(/waitFor\s*\(.*\)?/)
}
})

test('fakeTimer works', async () => {
vi.useFakeTimers()

const { setTimeout: safeSetTimeout } = getSafeTimers()

safeSetTimeout(() => {
vi.advanceTimersByTime(2000)
}, 500)

await vi.waitFor(() => {
return new Promise<void>((resolve) => {
setTimeout(() => {
resolve()
}, 1500)
})
}, 500)

vi.useRealTimers()
})
})