Skip to content

Commit

Permalink
Use unique thread ID for each partial render to access Context (faceb…
Browse files Browse the repository at this point in the history
…ook#14182)

* BUG: ReactPartialRenderer / New Context polutes mutable global state

The new context API stores the provided values on the shared context instance. When used in a synchronous context, this is not an issue. However when used in an concurrent context this can cause a "push provider" from one react render to have an effect on an unrelated concurrent react render.

I've encountered this bug in production when using renderToNodeStream, which asks ReactPartialRenderer for bytes up to a high water mark before yielding. If two Node Streams are created and read from in parallel, the state of one can polute the other.

I wrote a failing test to illustrate the conditions under which this happens.

I'm also concerned that the experimental concurrent/async React rendering on the client could suffer from the same issue.

* Use unique thread ID for each partial render to access Context

This first adds an allocator that keeps track of a unique ThreadID index
for each currently executing partial renderer. IDs are not just growing
but are reused as streams are destroyed.

This ensures that IDs are kept nice and compact.

This lets us use an "array" for each Context object to store the current
values. The look up for these are fast because they're just looking up
an offset in a tightly packed "array".

I don't use an actual Array object to store the values. Instead, I rely
on that VMs (notably V8) treat storage of numeric index property access
as a separate "elements" allocation.

This lets us avoid an extra indirection.

However, we must ensure that these arrays are not holey to preserve this
feature.

To do that I store the _threadCount on each context (effectively it takes
the place of the .length property on an array).

This lets us first validate that the context has enough slots before we
access the slot. If not, we fill in the slots with the default value.
  • Loading branch information
sebmarkbage committed Nov 9, 2018
1 parent 1a6ab1e commit 961eb65
Show file tree
Hide file tree
Showing 12 changed files with 336 additions and 74 deletions.
Expand Up @@ -348,5 +348,96 @@ describe('ReactDOMServerIntegration', () => {
await render(<App />, 1);
},
);

it('does not pollute parallel node streams', () => {
const LoggedInUser = React.createContext();

const AppWithUser = user => (
<LoggedInUser.Provider value={user}>
<header>
<LoggedInUser.Consumer>{whoAmI => whoAmI}</LoggedInUser.Consumer>
</header>
<footer>
<LoggedInUser.Consumer>{whoAmI => whoAmI}</LoggedInUser.Consumer>
</footer>
</LoggedInUser.Provider>
);

const streamAmy = ReactDOMServer.renderToNodeStream(
AppWithUser('Amy'),
).setEncoding('utf8');
const streamBob = ReactDOMServer.renderToNodeStream(
AppWithUser('Bob'),
).setEncoding('utf8');

// Testing by filling the buffer using internal _read() with a small
// number of bytes to avoid a test case which needs to align to a
// highWaterMark boundary of 2^14 chars.
streamAmy._read(20);
streamBob._read(20);
streamAmy._read(20);
streamBob._read(20);

expect(streamAmy.read()).toBe('<header>Amy</header><footer>Amy</footer>');
expect(streamBob.read()).toBe('<header>Bob</header><footer>Bob</footer>');
});

it('does not pollute parallel node streams when many are used', () => {
const CurrentIndex = React.createContext();

const NthRender = index => (
<CurrentIndex.Provider value={index}>
<header>
<CurrentIndex.Consumer>{idx => idx}</CurrentIndex.Consumer>
</header>
<footer>
<CurrentIndex.Consumer>{idx => idx}</CurrentIndex.Consumer>
</footer>
</CurrentIndex.Provider>
);

let streams = [];

// Test with more than 32 streams to test that growing the thread count
// works properly.
let streamCount = 34;

for (let i = 0; i < streamCount; i++) {
streams[i] = ReactDOMServer.renderToNodeStream(
NthRender(i % 2 === 0 ? 'Expected to be recreated' : i),
).setEncoding('utf8');
}

// Testing by filling the buffer using internal _read() with a small
// number of bytes to avoid a test case which needs to align to a
// highWaterMark boundary of 2^14 chars.
for (let i = 0; i < streamCount; i++) {
streams[i]._read(20);
}

// Early destroy every other stream
for (let i = 0; i < streamCount; i += 2) {
streams[i].destroy();
}

// Recreate those same streams.
for (let i = 0; i < streamCount; i += 2) {
streams[i] = ReactDOMServer.renderToNodeStream(
NthRender(i),
).setEncoding('utf8');
}

// Read a bit from all streams again.
for (let i = 0; i < streamCount; i++) {
streams[i]._read(20);
}

// Assert that all stream rendered the expected output.
for (let i = 0; i < streamCount; i++) {
expect(streams[i].read()).toBe(
'<header>' + i + '</header><footer>' + i + '</footer>',
);
}
});
});
});
6 changes: 5 additions & 1 deletion packages/react-dom/src/server/ReactDOMNodeStreamRenderer.js
Expand Up @@ -18,11 +18,15 @@ class ReactMarkupReadableStream extends Readable {
this.partialRenderer = new ReactPartialRenderer(element, makeStaticMarkup);
}

_destroy() {
this.partialRenderer.destroy();
}

_read(size) {
try {
this.push(this.partialRenderer.read(size));
} catch (err) {
this.emit('error', err);
this.destroy(err);
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions packages/react-dom/src/server/ReactDOMStringRenderer.js
Expand Up @@ -14,8 +14,12 @@ import ReactPartialRenderer from './ReactPartialRenderer';
*/
export function renderToString(element) {
const renderer = new ReactPartialRenderer(element, false);
const markup = renderer.read(Infinity);
return markup;
try {
const markup = renderer.read(Infinity);
return markup;
} finally {
renderer.destroy();
}
}

/**
Expand All @@ -25,6 +29,10 @@ export function renderToString(element) {
*/
export function renderToStaticMarkup(element) {
const renderer = new ReactPartialRenderer(element, true);
const markup = renderer.read(Infinity);
return markup;
try {
const markup = renderer.read(Infinity);
return markup;
} finally {
renderer.destroy();
}
}
102 changes: 35 additions & 67 deletions packages/react-dom/src/server/ReactPartialRenderer.js
Expand Up @@ -7,6 +7,7 @@
* @flow
*/

import type {ThreadID} from './ReactThreadIDAllocator';
import type {ReactElement} from 'shared/ReactElementType';
import type {ReactProvider, ReactContext} from 'shared/ReactTypes';

Expand All @@ -16,7 +17,6 @@ import getComponentName from 'shared/getComponentName';
import lowPriorityWarning from 'shared/lowPriorityWarning';
import warning from 'shared/warning';
import warningWithoutStack from 'shared/warningWithoutStack';
import checkPropTypes from 'prop-types/checkPropTypes';
import describeComponentFrame from 'shared/describeComponentFrame';
import ReactSharedInternals from 'shared/ReactSharedInternals';
import {
Expand All @@ -39,6 +39,12 @@ import {
REACT_MEMO_TYPE,
} from 'shared/ReactSymbols';

import {
emptyObject,
processContext,
validateContextBounds,
} from './ReactPartialRendererContext';
import {allocThreadID, freeThreadID} from './ReactThreadIDAllocator';
import {
createMarkupForCustomAttribute,
createMarkupForProperty,
Expand All @@ -50,6 +56,8 @@ import {
finishHooks,
Dispatcher,
DispatcherWithoutHooks,
currentThreadID,
setCurrentThreadID,
} from './ReactPartialRendererHooks';
import {
Namespaces,
Expand Down Expand Up @@ -176,7 +184,6 @@ const didWarnAboutBadClass = {};
const didWarnAboutDeprecatedWillMount = {};
const didWarnAboutUndefinedDerivedState = {};
const didWarnAboutUninitializedState = {};
const didWarnAboutInvalidateContextType = {};
const valuePropNames = ['value', 'defaultValue'];
const newlineEatingTags = {
listing: true,
Expand Down Expand Up @@ -324,65 +331,6 @@ function flattenOptionChildren(children: mixed): ?string {
return content;
}

const emptyObject = {};
if (__DEV__) {
Object.freeze(emptyObject);
}

function maskContext(type, context) {
const contextTypes = type.contextTypes;
if (!contextTypes) {
return emptyObject;
}
const maskedContext = {};
for (const contextName in contextTypes) {
maskedContext[contextName] = context[contextName];
}
return maskedContext;
}

function checkContextTypes(typeSpecs, values, location: string) {
if (__DEV__) {
checkPropTypes(
typeSpecs,
values,
location,
'Component',
getCurrentServerStackImpl,
);
}
}

function processContext(type, context) {
const contextType = type.contextType;
if (typeof contextType === 'object' && contextType !== null) {
if (__DEV__) {
if (contextType.$$typeof !== REACT_CONTEXT_TYPE) {
let name = getComponentName(type) || 'Component';
if (!didWarnAboutInvalidateContextType[name]) {
didWarnAboutInvalidateContextType[type] = true;
warningWithoutStack(
false,
'%s defines an invalid contextType. ' +
'contextType should point to the Context object returned by React.createContext(). ' +
'Did you accidentally pass the Context.Provider instead?',
name,
);
}
}
}
return contextType._currentValue;
} else {
const maskedContext = maskContext(type, context);
if (__DEV__) {
if (type.contextTypes) {
checkContextTypes(type.contextTypes, maskedContext, 'context');
}
}
return maskedContext;
}
}

const hasOwnProperty = Object.prototype.hasOwnProperty;
const STYLE = 'style';
const RESERVED_PROPS = {
Expand Down Expand Up @@ -453,6 +401,7 @@ function validateRenderResult(child, type) {
function resolve(
child: mixed,
context: Object,
threadID: ThreadID,
): {|
child: mixed,
context: Object,
Expand All @@ -472,7 +421,7 @@ function resolve(

// Extra closure so queue and replace can be captured properly
function processChild(element, Component) {
let publicContext = processContext(Component, context);
let publicContext = processContext(Component, context, threadID);

let queue = [];
let replace = false;
Expand Down Expand Up @@ -718,6 +667,7 @@ type FrameDev = Frame & {
};

class ReactDOMServerRenderer {
threadID: ThreadID;
stack: Array<Frame>;
exhausted: boolean;
// TODO: type this more strictly:
Expand Down Expand Up @@ -747,6 +697,7 @@ class ReactDOMServerRenderer {
if (__DEV__) {
((topFrame: any): FrameDev).debugElementStack = [];
}
this.threadID = allocThreadID();
this.stack = [topFrame];
this.exhausted = false;
this.currentSelectValue = null;
Expand All @@ -763,6 +714,13 @@ class ReactDOMServerRenderer {
}
}

destroy() {
if (!this.exhausted) {
this.exhausted = true;
freeThreadID(this.threadID);
}
}

/**
* Note: We use just two stacks regardless of how many context providers you have.
* Providers are always popped in the reverse order to how they were pushed
Expand All @@ -776,7 +734,9 @@ class ReactDOMServerRenderer {
pushProvider<T>(provider: ReactProvider<T>): void {
const index = ++this.contextIndex;
const context: ReactContext<any> = provider.type._context;
const previousValue = context._currentValue;
const threadID = this.threadID;
validateContextBounds(context, threadID);
const previousValue = context[threadID];

// Remember which value to restore this context to on our way up.
this.contextStack[index] = context;
Expand All @@ -787,7 +747,7 @@ class ReactDOMServerRenderer {
}

// Mutate the current value.
context._currentValue = provider.props.value;
context[threadID] = provider.props.value;
}

popProvider<T>(provider: ReactProvider<T>): void {
Expand All @@ -813,14 +773,18 @@ class ReactDOMServerRenderer {
this.contextIndex--;

// Restore to the previous value we stored as we were walking down.
context._currentValue = previousValue;
// We've already verified that this context has been expanded to accommodate
// this thread id, so we don't need to do it again.
context[this.threadID] = previousValue;
}

read(bytes: number): string | null {
if (this.exhausted) {
return null;
}

const prevThreadID = currentThreadID;
setCurrentThreadID(this.threadID);
const prevDispatcher = ReactCurrentOwner.currentDispatcher;
if (enableHooks) {
ReactCurrentOwner.currentDispatcher = Dispatcher;
Expand All @@ -835,6 +799,7 @@ class ReactDOMServerRenderer {
while (out[0].length < bytes) {
if (this.stack.length === 0) {
this.exhausted = true;
freeThreadID(this.threadID);
break;
}
const frame: Frame = this.stack[this.stack.length - 1];
Expand Down Expand Up @@ -906,6 +871,7 @@ class ReactDOMServerRenderer {
return out[0];
} finally {
ReactCurrentOwner.currentDispatcher = prevDispatcher;
setCurrentThreadID(prevThreadID);
}
}

Expand All @@ -929,7 +895,7 @@ class ReactDOMServerRenderer {
return escapeTextForBrowser(text);
} else {
let nextChild;
({child: nextChild, context} = resolve(child, context));
({child: nextChild, context} = resolve(child, context, this.threadID));
if (nextChild === null || nextChild === false) {
return '';
} else if (!React.isValidElement(nextChild)) {
Expand Down Expand Up @@ -1136,7 +1102,9 @@ class ReactDOMServerRenderer {
}
}
const nextProps: any = (nextChild: any).props;
const nextValue = reactContext._currentValue;
const threadID = this.threadID;
validateContextBounds(reactContext, threadID);
const nextValue = reactContext[threadID];

const nextChildren = toArray(nextProps.children(nextValue));
const frame: Frame = {
Expand Down

0 comments on commit 961eb65

Please sign in to comment.