diff --git a/src/baseWrapper.ts b/src/baseWrapper.ts index 2bce96558..42eeed23d 100644 --- a/src/baseWrapper.ts +++ b/src/baseWrapper.ts @@ -1,4 +1,4 @@ -import { isNotNullOrUndefined, textContent } from './utils' +import { textContent } from './utils' import type { TriggerOptions } from './createDomEvent' import { ComponentInternalInstance, @@ -41,6 +41,21 @@ export default abstract class BaseWrapper this.wrapperElement = element } + protected findAllDOMElements(selector: string): Element[] { + const elementRootNodes = this.getRootNodes().filter(isElement) + if (elementRootNodes.length === 0) return [] + + const result: Element[] = [ + ...elementRootNodes.filter((node) => node.matches(selector)) + ] + + elementRootNodes.forEach((rootNode) => { + result.push(...Array.from(rootNode.querySelectorAll(selector))) + }) + + return result + } + find( selector: K ): DOMWrapper @@ -65,24 +80,9 @@ export default abstract class BaseWrapper } } - const elementRootNodes = this.getRootNodes().filter( - (node): node is Element => node instanceof Element - ) - if (elementRootNodes.length === 0) { - return createWrapperError('DOMWrapper') - } - const matchingRootNode = elementRootNodes.find((node) => - node.matches(selector) - ) - if (matchingRootNode) { - return createDOMWrapper(matchingRootNode) - } - - const result = elementRootNodes - .map((node) => node.querySelector(selector)) - .filter(isNotNullOrUndefined) - if (result.length > 0) { - return createDOMWrapper(result[0]) + const elements = this.findAllDOMElements(selector) + if (elements.length > 0) { + return createDOMWrapper(elements[0]) } return createWrapperError('DOMWrapper') @@ -96,20 +96,7 @@ export default abstract class BaseWrapper ): DOMWrapper[] findAll(selector: string): DOMWrapper[] findAll(selector: string): DOMWrapper[] { - if (!isElement(this.element)) { - return [] - } - - const result = this.element.matches(selector) - ? [createDOMWrapper(this.element)] - : [] - - return [ - ...result, - ...Array.from(this.element.querySelectorAll(selector)).map((x) => - createDOMWrapper(x) - ) - ] + return this.findAllDOMElements(selector).map(createDOMWrapper) } // searching by string without specifying component results in WrapperLike object diff --git a/tests/components/MultipleRootRender.vue b/tests/components/MultipleRootRender.vue new file mode 100644 index 000000000..e99735f26 --- /dev/null +++ b/tests/components/MultipleRootRender.vue @@ -0,0 +1,14 @@ + + + diff --git a/tests/find.spec.ts b/tests/find.spec.ts index a296ab438..414b401f5 100644 --- a/tests/find.spec.ts +++ b/tests/find.spec.ts @@ -2,6 +2,7 @@ import { defineComponent, h, nextTick, Fragment } from 'vue' import { mount, VueWrapper } from '../src' import SuspenseComponent from './components/Suspense.vue' +import MultipleRootRender from './components/MultipleRootRender.vue' describe('find', () => { it('find using single root node', () => { @@ -156,6 +157,11 @@ describe('find', () => { const etc = wrapper.findComponent({ name: 'EmptyTestComponent' }) expect(etc.find('p').exists()).toBe(false) }) + + it('finds root node with SFC render function', () => { + const wrapper = mount(MultipleRootRender) + expect(wrapper.find('a').exists()).toBe(true) + }) }) describe('findAll', () => { @@ -335,5 +341,10 @@ describe('findAll', () => { .exists() ).toBe(true) }) + + it('finds all root nodes with SFC render function', () => { + const wrapper = mount(MultipleRootRender) + expect(wrapper.findAll('a')).toHaveLength(3) + }) }) })