Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vitest): support vi.waitFor method #4113

Merged
merged 18 commits into from Sep 14, 2023
Merged
65 changes: 62 additions & 3 deletions docs/api/vi.md
Expand Up @@ -254,10 +254,10 @@ import { vi } from 'vitest'
```ts
// increment.test.js
import { vi } from 'vitest'

// axios is a default export from `__mocks__/axios.js`
import axios from 'axios'

// increment is a named export from `src/__mocks__/increment.js`
import { increment } from '../increment.js'

Expand Down Expand Up @@ -371,7 +371,7 @@ test('importing the next module imports mocked one', async () => {

```ts
import { vi } from 'vitest'

import { data } from './data.js' // Will not get reevaluated beforeEach test

beforeEach(() => {
Expand Down Expand Up @@ -717,3 +717,62 @@ unmockedIncrement(30) === 31
- **Type:** `() => Vitest`

When timers are run out, you may call this method to return mocked timers to its original implementations. All timers that were run before will not be restored.

### vi.waitFor

- **Type:** `function waitFor<T>(callback: WaitForCallback<T>, options?: number | WaitForOptions): Promise<T>`
- **Version**: Since Vitest 0.35.0
Dunqing marked this conversation as resolved.
Show resolved Hide resolved

Wait for the callback to execute successfully. If the callback throws an error or returns a rejected promise it will continue to wait until it succeeds or times out.

This is very useful when you need to wait for some asynchronous action to complete, for example, when you start a server and need to wait for it to start.

```ts
import { test, vi } from 'vitest'

test('Server started successfully', async () => {
let server = false

setTimeout(() => {
server = true
}, 100)

function checkServerStart() {
if (!server)
throw new Error('Server not started')

console.log('Server started')
}

const res = await vi.waitFor(checkServerStart, {
Dunqing marked this conversation as resolved.
Show resolved Hide resolved
timeout: 500, // default is 1000
interval: 20, // default is 50
})
expect(server).toBe(true)
})
```

It also works for asynchronous callbacks

```ts
import { test, vi } from 'vitest'

test('Server started successfully', async () => {
async function startServer() {
return new Promise((resolve) => {
setTimeout(() => {
server = true
resolve('Server started')
}, 100)
})
}

const server = await vi.waitFor(startServer, {
timeout: 500, // default is 1000
interval: 20, // default is 50
})
expect(server).toBe('Server started')
})
```

If `vi.useFakeTimers` is used, `vi.waitFor` automatically calls `vi.AdvanceTimersByTime(interval)` in every check callback.
4 changes: 3 additions & 1 deletion packages/vitest/src/integrations/vi.ts
Expand Up @@ -9,6 +9,7 @@ import { resetModules, waitForImportsToResolve } from '../utils/modules'
import { FakeTimers } from './mock/timers'
import type { EnhancedSpy, MaybeMocked, MaybeMockedDeep, MaybePartiallyMocked, MaybePartiallyMockedDeep } from './spy'
import { fn, isMockFunction, spies, spyOn } from './spy'
import { waitFor } from './wait'

interface VitestUtils {
useFakeTimers(config?: FakeTimerInstallOpts): this
Expand All @@ -30,6 +31,7 @@ interface VitestUtils {

spyOn: typeof spyOn
fn: typeof fn
waitFor: typeof waitFor

/**
* Run the factory before imports are evaluated. You can return a value from the factory
Expand Down Expand Up @@ -292,7 +294,7 @@ function createVitest(): VitestUtils {

spyOn,
fn,

waitFor,
hoisted<T>(factory: () => T): T {
assertTypes(factory, '"vi.hoisted" factory', ['function'])
return factory()
Expand Down
98 changes: 98 additions & 0 deletions packages/vitest/src/integrations/wait.ts
@@ -0,0 +1,98 @@
import { getSafeTimers } from '@vitest/utils'
import { vi } from './vi'

// The waitFor function was inspired by https://github.com/testing-library/web-testing-library/pull/2

export type WaitForCallback<T> = () => T | Promise<T>

export interface WaitForOptions {
/**
* @description Time in ms between each check callback
* @default 50ms
*/
interval?: number
/**
* @description Time in ms after which the throw a timeout error
* @default 1000ms
*/
timeout?: number
}

function copyStackTrace(target: Error, source: Error) {
if (source.stack !== undefined)
target.stack = source.stack.replace(source.message, target.message)
return target
}

export function waitFor<T>(callback: WaitForCallback<T>, options: number | WaitForOptions = {}) {
const { setTimeout, setInterval, clearTimeout, clearInterval } = getSafeTimers()
const { interval = 50, timeout = 1000 } = typeof options === 'number' ? { timeout: options } : options
const STACK_TRACE_ERROR = new Error('STACK_TRACE_ERROR')

return new Promise<T>((resolve, reject) => {
let lastError: unknown
let promiseStatus: 'idle' | 'pending' | 'resolved' | 'rejected' = 'idle'
let timeoutId: ReturnType<typeof setTimeout>
let intervalId: ReturnType<typeof setInterval>

const onResolve = (result: T) => {
if (timeoutId)
clearTimeout(timeoutId)
if (intervalId)
clearInterval(intervalId)

resolve(result)
}

const handleTimeout = () => {
let error = lastError
if (!error)
error = copyStackTrace(new Error('Timed out in waitFor!'), STACK_TRACE_ERROR)
sheremet-va marked this conversation as resolved.
Show resolved Hide resolved
Dunqing marked this conversation as resolved.
Show resolved Hide resolved

reject(error)
}

const checkCallback = () => {
// use fake timers
if (globalThis.setTimeout !== setTimeout)
Dunqing marked this conversation as resolved.
Show resolved Hide resolved
vi.advanceTimersByTime(interval)

if (promiseStatus === 'pending')
return
try {
const result = callback()
if (
result !== null
&& typeof result === 'object'
&& typeof (result as any).then === 'function'
) {
const thenable = result as PromiseLike<T>
promiseStatus = 'pending'
thenable.then(
(resolvedValue) => {
promiseStatus = 'resolved'
onResolve(resolvedValue)
},
(rejectedValue) => {
promiseStatus = 'rejected'
lastError = rejectedValue
},
)
}
else {
onResolve(result as T)
return true
}
}
catch (error) {
lastError = error
}
}

if (checkCallback() === true)
return

timeoutId = setTimeout(handleTimeout, timeout)
intervalId = setInterval(checkCallback, interval)
})
}
104 changes: 104 additions & 0 deletions test/core/test/wait.test.ts
@@ -0,0 +1,104 @@
import { describe, expect, test, vi } from 'vitest'

describe('waitFor', () => {
describe('options', () => {
test('timeout', async () => {
expect(async () => {
await vi.waitFor(() => {
return new Promise((resolve) => {
setTimeout(() => {
resolve(true)
}, 100)
})
}, 50)
}).rejects.toThrow('Timed out in waitFor!')
})

test('interval', async () => {
const callback = vi.fn(() => {
throw new Error('interval error')
})

await expect(
vi.waitFor(callback, {
timeout: 60,
interval: 30,
}),
).rejects.toThrowErrorMatchingInlineSnapshot('"interval error"')

expect(callback).toHaveBeenCalledTimes(2)
})
})

test('basic', async () => {
let throwError = false
await vi.waitFor(() => {
if (!throwError) {
throwError = true
throw new Error('basic error')
}
})
expect(throwError).toBe(true)
})

test('async function', async () => {
let finished = false
setTimeout(() => {
finished = true
}, 50)
await vi.waitFor(async () => {
if (finished)
return Promise.resolve(true)
else
return Promise.reject(new Error('async function error'))
})
})

test('stacktrace correctly', async () => {
const check = () => {
const _a = 1
// @ts-expect-error test
_a += 1
}
try {
await vi.waitFor(check, 100)
}
catch (error) {
expect((error as Error).message).toMatchInlineSnapshot('"Assignment to constant variable."')
expect.soft((error as Error).stack).toMatch(/at check/)
}
})

test('stacktrace point to waitFor', async () => {
const check = async () => {
return new Promise((resolve) => {
setTimeout(resolve, 60)
})
}
try {
await vi.waitFor(check, 50)
}
catch (error) {
expect(error).toMatchInlineSnapshot('[Error: Timed out in waitFor!]')
expect((error as Error).stack?.split('\n')[1]).toMatch(/waitFor\s*\(.*\)?/)
}
})

test('fakeTimer works', async () => {
vi.useFakeTimers()

setTimeout(() => {
vi.advanceTimersByTime(200)
}, 50)

await vi.waitFor(() => {
return new Promise<void>((resolve) => {
setTimeout(() => {
resolve()
}, 150)
})
}, 200)

vi.useRealTimers()
})
})