diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 670fcd4994..881105548e 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -293,6 +293,15 @@ export class Connection extends TypedEventEmitter { this[kHello] = response; } + // Set the whether the message stream is for a monitoring connection. + set isMonitoringConnection(value: boolean) { + this[kMessageStream].isMonitoringConnection = value; + } + + get isMonitoringConnection(): boolean { + return this[kMessageStream].isMonitoringConnection; + } + get serviceId(): ObjectId | undefined { return this.hello?.serviceId; } diff --git a/src/cmap/message_stream.ts b/src/cmap/message_stream.ts index 55ed847e95..542e35cf73 100644 --- a/src/cmap/message_stream.ts +++ b/src/cmap/message_stream.ts @@ -53,6 +53,8 @@ export class MessageStream extends Duplex { maxBsonMessageSize: number; /** @internal */ [kBuffer]: BufferPool; + /** @internal */ + isMonitoringConnection = false; constructor(options: MessageStreamOptions = {}) { super(options); @@ -60,6 +62,10 @@ export class MessageStream extends Duplex { this[kBuffer] = new BufferPool(); } + get buffer(): BufferPool { + return this[kBuffer]; + } + override _write(chunk: Buffer, _: unknown, callback: Callback): void { this[kBuffer].append(chunk); processIncomingData(this, callback); @@ -162,15 +168,36 @@ function processIncomingData(stream: MessageStream, callback: Callback) opCode: message.readInt32LE(12) }; + const monitorHasAnotherHello = () => { + if (stream.isMonitoringConnection) { + // Can we read the next message size? + if (buffer.length >= 4) { + const sizeOfMessage = buffer.peek(4).readInt32LE(); + if (sizeOfMessage <= buffer.length) { + return true; + } + } + } + return false; + }; + let ResponseType = messageHeader.opCode === OP_MSG ? BinMsg : Response; if (messageHeader.opCode !== OP_COMPRESSED) { const messageBody = message.slice(MESSAGE_HEADER_SIZE); - stream.emit('message', new ResponseType(message, messageHeader, messageBody)); - if (buffer.length >= 4) { + // If we are a monitoring connection message stream and + // there is more in the buffer that can be read, skip processing since we + // want the last hello command response that is in the buffer. + if (monitorHasAnotherHello()) { processIncomingData(stream, callback); } else { - callback(); + stream.emit('message', new ResponseType(message, messageHeader, messageBody)); + + if (buffer.length >= 4) { + processIncomingData(stream, callback); + } else { + callback(); + } } return; @@ -198,12 +225,19 @@ function processIncomingData(stream: MessageStream, callback: Callback) return; } - stream.emit('message', new ResponseType(message, messageHeader, messageBody)); - - if (buffer.length >= 4) { + // If we are a monitoring connection message stream and + // there is more in the buffer that can be read, skip processing since we + // want the last hello command response that is in the buffer. + if (monitorHasAnotherHello()) { processIncomingData(stream, callback); } else { - callback(); + stream.emit('message', new ResponseType(message, messageHeader, messageBody)); + + if (buffer.length >= 4) { + processIncomingData(stream, callback); + } else { + callback(); + } } }); } diff --git a/src/sdam/monitor.ts b/src/sdam/monitor.ts index 92d626bd91..1cab4b5d16 100644 --- a/src/sdam/monitor.ts +++ b/src/sdam/monitor.ts @@ -88,6 +88,10 @@ export class Monitor extends TypedEventEmitter { [kMonitorId]?: InterruptibleAsyncInterval; [kRTTPinger]?: RTTPinger; + get connection(): Connection | undefined { + return this[kConnection]; + } + constructor(server: Server, options: MonitorOptions) { super(); @@ -310,6 +314,10 @@ function checkServer(monitor: Monitor, callback: Callback) { } if (conn) { + // Tell the connection that we are using the streaming protocol so that the + // connection's message stream will only read the last hello on the buffer. + conn.isMonitoringConnection = true; + if (isInCloseState(monitor)) { conn.destroy({ force: true }); return; diff --git a/test/tools/utils.ts b/test/tools/utils.ts index de5addfebc..95b61773d3 100644 --- a/test/tools/utils.ts +++ b/test/tools/utils.ts @@ -1,7 +1,10 @@ import { EJSON } from 'bson'; +import * as BSON from 'bson'; import { expect } from 'chai'; import { inspect, promisify } from 'util'; +import { OP_MSG } from '../../src/cmap/wire_protocol/constants'; +import { Document } from '../../src/index'; import { Logger } from '../../src/logger'; import { deprecateOptions, DeprecateOptionsConfig } from '../../src/utils'; import { runUnifiedSuite } from './unified-spec-runner/runner'; @@ -343,6 +346,24 @@ export class TestBuilder { } } +export function generateOpMsgBuffer(document: Document): Buffer { + const header = Buffer.alloc(4 * 4 + 4); + + const typeBuffer = Buffer.alloc(1); + typeBuffer[0] = 0; + + const docBuffer = BSON.serialize(document); + + const totalLength = header.length + typeBuffer.length + docBuffer.length; + + header.writeInt32LE(totalLength, 0); + header.writeInt32LE(0, 4); + header.writeInt32LE(0, 8); + header.writeInt32LE(OP_MSG, 12); + header.writeUInt32LE(0, 16); + return Buffer.concat([header, typeBuffer, docBuffer]); +} + export class UnifiedTestSuiteBuilder { private _description = 'Default Description'; private _schemaVersion = '1.0'; diff --git a/test/unit/cmap/message_stream.test.js b/test/unit/cmap/message_stream.test.js index ca8ab447f4..3158a9144a 100644 --- a/test/unit/cmap/message_stream.test.js +++ b/test/unit/cmap/message_stream.test.js @@ -1,10 +1,12 @@ 'use strict'; -const Readable = require('stream').Readable; -const Writable = require('stream').Writable; +const { on, once } = require('events'); +const { Readable, Writable } = require('stream'); + const { MessageStream } = require('../../../src/cmap/message_stream'); const { Msg } = require('../../../src/cmap/commands'); const expect = require('chai').expect; const { LEGACY_HELLO_COMMAND } = require('../../../src/constants'); +const { generateOpMsgBuffer } = require('../../tools/utils'); function bufferToStream(buffer) { const stream = new Readable(); @@ -18,117 +20,117 @@ function bufferToStream(buffer) { return stream; } -describe('Message Stream', function () { - describe('reading', function () { - [ - { - description: 'valid OP_REPLY', - data: Buffer.from( - '370000000100000001000000010000000000000000000000000000000000000001000000130000001069736d6173746572000100000000', - 'hex' - ), - documents: [{ [LEGACY_HELLO_COMMAND]: 1 }] - }, - { - description: 'valid multiple OP_REPLY', - expectedMessageCount: 4, - data: Buffer.from( - '370000000100000001000000010000000000000000000000000000000000000001000000130000001069736d6173746572000100000000' + - '370000000100000001000000010000000000000000000000000000000000000001000000130000001069736d6173746572000100000000' + - '370000000100000001000000010000000000000000000000000000000000000001000000130000001069736d6173746572000100000000' + - '370000000100000001000000010000000000000000000000000000000000000001000000130000001069736d6173746572000100000000', - 'hex' - ), - documents: [{ [LEGACY_HELLO_COMMAND]: 1 }] - }, - { - description: 'valid OP_REPLY (partial)', - data: [ - Buffer.from('37', 'hex'), - Buffer.from('0000', 'hex'), - Buffer.from( - '000100000001000000010000000000000000000000000000000000000001000000130000001069736d6173746572000100000000', - 'hex' - ) - ], - documents: [{ [LEGACY_HELLO_COMMAND]: 1 }] - }, - - { - description: 'valid OP_MSG', - data: Buffer.from( - '370000000100000000000000dd0700000000000000220000001069736d6173746572000100000002246462000600000061646d696e0000', - 'hex' - ), - documents: [{ $db: 'admin', [LEGACY_HELLO_COMMAND]: 1 }] - }, - { - description: 'valid multiple OP_MSG', - expectedMessageCount: 4, - data: Buffer.from( - '370000000100000000000000dd0700000000000000220000001069736d6173746572000100000002246462000600000061646d696e0000' + - '370000000100000000000000dd0700000000000000220000001069736d6173746572000100000002246462000600000061646d696e0000' + - '370000000100000000000000dd0700000000000000220000001069736d6173746572000100000002246462000600000061646d696e0000' + - '370000000100000000000000dd0700000000000000220000001069736d6173746572000100000002246462000600000061646d696e0000', - 'hex' - ), - documents: [{ $db: 'admin', [LEGACY_HELLO_COMMAND]: 1 }] - }, - - { - description: 'Invalid message size (negative)', - data: Buffer.from('ffffffff', 'hex'), - error: 'Invalid message size: -1' - }, - { - description: 'Invalid message size (exceeds maximum)', - data: Buffer.from('01000004', 'hex'), - error: 'Invalid message size: 67108865, max allowed: 67108864' - } - ].forEach(test => { - it(test.description, function (done) { - const error = test.error; - const expectedMessageCount = test.expectedMessageCount || 1; - const inputStream = bufferToStream(test.data); - const messageStream = new MessageStream(); +describe('MessageStream', function () { + context('when the stream is for a monitoring connection', function () { + const response = { isWritablePrimary: true }; + const lastResponse = { ok: 1 }; + let firstHello; + let secondHello; + let thirdHello; + let partial; + + beforeEach(function () { + firstHello = generateOpMsgBuffer(response); + secondHello = generateOpMsgBuffer(response); + thirdHello = generateOpMsgBuffer(lastResponse); + partial = Buffer.alloc(5); + partial.writeInt32LE(100, 0); + }); - let messageCount = 0; - messageStream.on('message', msg => { - messageCount++; - if (error) { - done(new Error(`expected error: ${error}`)); - return; - } + it('only reads the last message in the buffer', async function () { + const inputStream = bufferToStream(Buffer.concat([firstHello, secondHello, thirdHello])); + const messageStream = new MessageStream(); + messageStream.isMonitoringConnection = true; + + inputStream.pipe(messageStream); + const messages = await once(messageStream, 'message'); + const msg = messages[0]; + msg.parse(); + expect(msg).to.have.property('documents').that.deep.equals([lastResponse]); + // Make sure there is nothing left in the buffer. + expect(messageStream.buffer.length).to.equal(0); + }); - msg.parse(); + it('does not read partial messages', async function () { + const inputStream = bufferToStream( + Buffer.concat([firstHello, secondHello, thirdHello, partial]) + ); + const messageStream = new MessageStream(); + messageStream.isMonitoringConnection = true; + + inputStream.pipe(messageStream); + const messages = await once(messageStream, 'message'); + const msg = messages[0]; + msg.parse(); + expect(msg).to.have.property('documents').that.deep.equals([lastResponse]); + // Make sure the buffer wasn't read to the end. + expect(messageStream.buffer.length).to.equal(5); + }); + }); - if (test.documents) { - expect(msg).to.have.property('documents').that.deep.equals(test.documents); - } + context('when the stream is not for a monitoring connection', function () { + context('when the messages are valid', function () { + const response = { isWritablePrimary: true }; + let firstHello; + let secondHello; + let thirdHello; + let messageCount = 0; + + beforeEach(function () { + firstHello = generateOpMsgBuffer(response); + secondHello = generateOpMsgBuffer(response); + thirdHello = generateOpMsgBuffer(response); + }); - if (messageCount === expectedMessageCount) { - done(); - } - }); + it('reads all messages in the buffer', async function () { + const inputStream = bufferToStream(Buffer.concat([firstHello, secondHello, thirdHello])); + const messageStream = new MessageStream(); - messageStream.on('error', err => { - if (error == null) { - done(err); + inputStream.pipe(messageStream); + for await (const messages of on(messageStream, 'message')) { + messageCount++; + const msg = messages[0]; + msg.parse(); + expect(msg).to.have.property('documents').that.deep.equals([response]); + // Test will not complete until 3 messages processed. + if (messageCount === 3) { return; } + } + }); + }); - expect(err).to.have.property('message').that.equals(error); + context('when the messages are invalid', function () { + context('when the message size is negative', function () { + it('emits an error', async function () { + const inputStream = bufferToStream(Buffer.from('ffffffff', 'hex')); + const messageStream = new MessageStream(); - done(); + inputStream.pipe(messageStream); + const errors = await once(messageStream, 'error'); + const err = errors[0]; + expect(err).to.have.property('message').that.equals('Invalid message size: -1'); }); + }); - inputStream.pipe(messageStream); + context('when the message size exceeds the bson maximum', function () { + it('emits an error', async function () { + const inputStream = bufferToStream(Buffer.from('01000004', 'hex')); + const messageStream = new MessageStream(); + + inputStream.pipe(messageStream); + const errors = await once(messageStream, 'error'); + const err = errors[0]; + expect(err) + .to.have.property('message') + .that.equals('Invalid message size: 67108865, max allowed: 67108864'); + }); }); }); }); - describe('writing', function () { - it('should write a message to the stream', function (done) { + context('when writing to the message stream', function () { + it('pushes the message', function (done) { const readableStream = new Readable({ read() {} }); const writeableStream = new Writable({ write: (chunk, _, callback) => { diff --git a/test/unit/sdam/monitor.test.js b/test/unit/sdam/monitor.test.js index 2d7b786e1c..a161202dbe 100644 --- a/test/unit/sdam/monitor.test.js +++ b/test/unit/sdam/monitor.test.js @@ -115,7 +115,10 @@ describe('monitoring', function () { monitor = new Monitor(server, {}); monitor.on('serverHeartbeatFailed', () => done(new Error('unexpected heartbeat failure'))); - monitor.on('serverHeartbeatSucceeded', () => done()); + monitor.on('serverHeartbeatSucceeded', () => { + expect(monitor.connection.isMonitoringConnection).to.be.true; + done(); + }); monitor.connect(); });