Skip to content

Commit

Permalink
Implement more robust socket hook protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
sdorra committed Nov 22, 2020
1 parent 73b2c4a commit abaa7b9
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 36 deletions.
Expand Up @@ -79,37 +79,47 @@ public void run() {
LOG.warn("failed to read hook request", e);
} finally {
LOG.trace("close client socket");
TransactionId.clear();
close();
}
}

private void handleHookRequest(InputStream input, OutputStream output) throws IOException {
Request request = Sockets.read(input, Request.class);
Request request = Sockets.receive(input, Request.class);
TransactionId.set(request.getTransactionId());
Response response = handleHookRequest(request);
Sockets.send(output, response);
}

private Response handleHookRequest(Request request) {
LOG.trace("process {} hook for node {}", request.getType(), request.getNode());
TransactionId.set(request.getTransactionId());

HgHookContextProvider context = hookContextProviderFactory.create(request.getRepositoryId(), request.getNode());
if (!environment.isAcceptAble(request.getChallenge())) {
LOG.warn("received hook with invalid challenge: {}", request.getChallenge());
return error("invalid hook challenge");
}

try {
if (!environment.isAcceptAble(request.getChallenge())) {
LOG.warn("received hook with invalid challenge: {}", request.getChallenge());
return error("invalid hook challenge");
}

authenticate(request);

return fireHook(request);
} catch (AuthenticationException ex) {
LOG.warn("hook authentication failed", ex);
return error("hook authentication failed");
}
}

@Nonnull
private Response fireHook(Request request) {
HgHookContextProvider context = hookContextProviderFactory.create(request.getRepositoryId(), request.getNode());

try {
environment.setPending(request.getType() == RepositoryHookType.PRE_RECEIVE);

hookEventFacade.handle(request.getRepositoryId()).fireHookEvent(request.getType(), context);

return new Response(context.getHgMessageProvider().getMessages(), false);
} catch (AuthenticationException ex) {
LOG.warn("hook authentication failed", ex);
return error("hook authentication failed");

} catch (NotFoundException ex) {
LOG.warn("could not find repository with id {}", request.getRepositoryId(), ex);
return error("repository not found");
Expand All @@ -121,7 +131,6 @@ private Response handleHookRequest(Request request) {
return error(context, "unknown error");
} finally {
environment.clearPendingState();
TransactionId.clear();
}
}

Expand Down
Expand Up @@ -25,33 +25,75 @@
package sonia.scm.repository.hooks;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

class Sockets {

private static final Logger LOG = LoggerFactory.getLogger(Sockets.class);

private static final int READ_LIMIT = 8192;

private static final ObjectMapper objectMapper = new ObjectMapper();

private Sockets() {
}

static void send(OutputStream out, Object object) throws IOException {
byte[] bytes = objectMapper.writeValueAsBytes(object);
LOG.trace("send message length of {} to socket", bytes.length);
writeInt(out, bytes.length);
LOG.trace("send message to socket");
out.write(bytes);
out.write('\0');
LOG.trace("flush socket");
out.flush();
}

static <T> T read(InputStream in, Class<T> type) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
int c = in.read();
while (c != '\0') {
buffer.write(c);
c = in.read();
static <T> T receive(InputStream in, Class<T> type) throws IOException {
LOG.trace("read {} from socket", type);
int length = readInt(in);
LOG.trace("read message length of {} from socket", length);
if (length > READ_LIMIT) {
String message = String.format("received length of %d, which exceeds the limit of %d", length, READ_LIMIT);
throw new IOException(message);
}
return objectMapper.readValue(buffer.toByteArray(), type);
byte[] data = read(in, length);
LOG.trace("convert message to {}", type);
return objectMapper.readValue(data, type);
}

static void writeInt(OutputStream out, int value) throws IOException {
out.write((value >>> 24) & 0xFF);
out.write((value >>> 16) & 0xFF);
out.write((value >>> 8) & 0xFF);
out.write(value & 0xFF);
}

static int readInt(InputStream in) throws IOException {
int b1 = in.read();
int b2 = in.read();
int b3 = in.read();
int b4 = in.read();

if ((b1 | b2 | b3 | b4) < 0) {
throw new EOFException("failed to read int from socket");
}

return ((b1 << 24) + (b2 << 16) + (b3 << 8) + b4);
}

private static byte[] read(InputStream in, int length) throws IOException {
byte[] buffer = new byte[length];
int read = in.read(buffer);
if (read < length) {
throw new EOFException("failed to read bytes from socket");
}
return buffer;
}

}
Expand Up @@ -29,7 +29,7 @@
# changegroup.scm = python:scmhooks.callback
#

import os, sys, json, socket
import os, sys, json, socket, struct

# read environment
port = os.environ['SCM_HOOK_PORT']
Expand All @@ -54,17 +54,19 @@ def fire_hook(ui, repo, hooktype, node):
values = {'token': token, 'type': hooktype, 'repositoryId': repositoryId, 'transactionId': transactionId, 'challenge': challenge, 'node': node.decode('utf8') }

connection.connect(("127.0.0.1", int(port)))
connection.send(json.dumps(values).encode('utf-8'))
connection.sendall(b'\0')

received = []
byte = connection.recv(1)
while byte != b'\0':
received.append(byte)
byte = connection.recv(1)
data = json.dumps(values).encode('utf-8')
connection.send(struct.pack('>i', len(data)))
connection.sendall(data)

message = b''.join(received).decode('utf-8')
response = json.loads(message)
d = connection.recv(4, socket.MSG_WAITALL)
length = struct.unpack('>i', bytearray(d))[0]
if length > 8192:
ui.warn( b"scm-hook received message with exceeds the limit of 8192\n" )
return True

d = connection.recv(length, socket.MSG_WAITALL)
response = json.loads(d.decode("utf-8"))

abort = response['abort']
print_messages(ui, response['messages'])
Expand Down Expand Up @@ -94,7 +96,7 @@ def pre_hook(ui, repo, hooktype, node=None, source=None, pending=None, **kwargs)

# newer mercurial version
# we have to make in-memory changes visible to external process
# this does not happen automatically, because mercurial treat our hooks as internal hooks
# this does not happen automatically, because mercurial treat our hooks as internal hook
# see hook.py at mercurial sources _exthook
try:
if repo is not None:
Expand All @@ -103,7 +105,7 @@ def pre_hook(ui, repo, hooktype, node=None, source=None, pending=None, **kwargs)
if tr and not tr.writepending():
ui.warn(b"no pending write transaction found")
except AttributeError:
ui.debug(b"mercurial does not support currenttransation")
ui.debug(b"mercurial does not support currenttransaction")
# do nothing

return callback(ui, repo, "PRE_RECEIVE", node)
Expand Down
Expand Up @@ -291,14 +291,16 @@ private DefaultHookHandler.Request createRequest(RepositoryHookType type, String
private DefaultHookHandler.Response send(DefaultHookHandler.Request request) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
Sockets.send(buffer, request);

ByteArrayInputStream input = new ByteArrayInputStream(buffer.toByteArray());
when(socket.getInputStream()).thenReturn(input);

ByteArrayOutputStream output = new ByteArrayOutputStream();
when(socket.getOutputStream()).thenReturn(output);

handler.run();

return Sockets.read(new ByteArrayInputStream(output.toByteArray()), DefaultHookHandler.Response.class);
return Sockets.receive(new ByteArrayInputStream(output.toByteArray()), DefaultHookHandler.Response.class);
}

private static class TestingException extends ExceptionWithContext {
Expand Down
Expand Up @@ -82,7 +82,7 @@ private Response send(Request request) throws IOException {
OutputStream output = socket.getOutputStream()
) {
Sockets.send(output, request);
return Sockets.read(input, Response.class);
return Sockets.receive(input, Response.class);
} catch (IOException ex) {
throw new RuntimeException("failed", ex);
}
Expand All @@ -100,7 +100,7 @@ private HelloHandler(Socket socket) {
@Override
public void run() {
try (InputStream input = socket.getInputStream(); OutputStream output = socket.getOutputStream()) {
Request request = Sockets.read(input, Request.class);
Request request = Sockets.receive(input, Request.class);
Subject subject = SecurityUtils.getSubject();
Sockets.send(output, new Response("Hello " + request.getName(), subject.getPrincipal().toString()));
} catch (IOException ex) {
Expand Down
@@ -0,0 +1,94 @@
/*
* MIT License
*
* Copyright (c) 2020-present Cloudogu GmbH and Contributors
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

package sonia.scm.repository.hooks;

import lombok.AllArgsConstructor;
import lombok.Data;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

class SocketsTest {

@Test
void shouldSendAndReceive() throws IOException {
ByteArrayOutputStream output = new ByteArrayOutputStream();
Sockets.send(output, new TestValue("awesome"));
ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
TestValue value = Sockets.receive(input, TestValue.class);
assertThat(value.value).isEqualTo("awesome");
}

@Test
void shouldFailWithTooFewBytesForLength() {
ByteArrayOutputStream output = new ByteArrayOutputStream();
output.write((512 >>> 24) & 0xFF);
output.write((512 >>> 16) & 0xFF);

ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
IOException ex = assertThrows(IOException.class, () -> Sockets.receive(input, TestValue.class));
assertThat(ex.getMessage()).contains("int");
}

@Test
void shouldFailWithTooFewBytesForData() {
ByteArrayOutputStream output = new ByteArrayOutputStream();
output.write((16 >>> 24) & 0xFF);
output.write((16 >>> 16) & 0xFF);
output.write((16 >>> 8) & 0xFF);
output.write(16 & 0xFF);

ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
IOException ex = assertThrows(IOException.class, () -> Sockets.receive(input, TestValue.class));
assertThat(ex.getMessage()).contains("bytes");
}

@Test
void shouldFailIfLimitIsExceeded() {
ByteArrayOutputStream output = new ByteArrayOutputStream();
output.write((9216 >>> 24) & 0xFF);
output.write((9216 >>> 16) & 0xFF);
output.write((9216 >>> 8) & 0xFF);
output.write(9216 & 0xFF);

ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray());
IOException ex = assertThrows(IOException.class, () -> Sockets.receive(input, TestValue.class));
assertThat(ex.getMessage()).contains("9216");
}

@Data
@AllArgsConstructor
public static class TestValue {

private String value;

}

}

0 comments on commit abaa7b9

Please sign in to comment.