Merge pull request #195 from HALFpipe/dev/fix/graph
Fix workflow caching
HippocampusGirl committed Sep 6, 2021
2 parents 534fefa + 2c75c04 commit 072b12f
Showing 7 changed files with 37 additions and 51 deletions.
4 changes: 2 additions & 2 deletions halfpipe/cli/run.py
@@ -71,14 +71,14 @@ def run_stage_workflow(opts):

 def run_stage_run(opts):
     if opts.graphs is None:
-        from ..io import loadpicklelzma
+        from ..io.file.pickle import load_pickle_lzma

         assert (
             opts.graphs_file is not None
         ), "Missing required --graphs-file input for step run"

         graphs_file = resolve(opts.graphs_file, opts.fs_root)
-        graphs = loadpicklelzma(graphs_file)
+        graphs = load_pickle_lzma(graphs_file)

         if not isinstance(graphs, OrderedDict):
             raise RuntimeError(
4 changes: 2 additions & 2 deletions halfpipe/cluster.py
@@ -7,7 +7,7 @@
 from math import ceil
 from collections import OrderedDict

-from .io import make_cachefilepath
+from .io.file.pickle import _make_cache_file_path
 from .utils import logger, inflect_engine as p
 from .workflow.execgraph import filter_subject_graphs

@@ -117,7 +117,7 @@ def create_example_script(workdir, graphs: OrderedDict, opts):
     n_chunks = len(subject_graphs)
     assert n_chunks > 0

-    graphs_file = make_cachefilepath("graphs", uuid)
+    graphs_file = _make_cache_file_path("graphs", uuid)

     n_cpus = 2
     nipype_max_mem_gb = max(node.mem_gb for graph in graphs.values() for node in graph.nodes)
11 changes: 0 additions & 11 deletions halfpipe/io/__init__.py
@@ -5,11 +5,6 @@
 from .file import (
     DictListFile,
     AdaptiveLock,
-    loadpicklelzma,
-    dumppicklelzma,
-    make_cachefilepath,
-    cacheobj,
-    uncacheobj,
 )

 from .parse import (
@@ -30,16 +25,10 @@
 __all__ = [
     "DictListFile",
     "AdaptiveLock",
-    "IndexedFile",
     "parse_condition_file",
     "parse_design",
     "loadspreadsheet",
     "loadmatrix",
-    "loadpicklelzma",
-    "dumppicklelzma",
-    "make_cachefilepath",
-    "cacheobj",
-    "uncacheobj",
     "MetadataLoader",
     "SidecarMetadataLoader",
     "slice_timing_str",
7 changes: 0 additions & 7 deletions halfpipe/io/file/__init__.py
@@ -6,14 +6,7 @@

 from .lock import AdaptiveLock

-from .pickle import loadpicklelzma, dumppicklelzma, make_cachefilepath, cacheobj, uncacheobj
-
 __all__ = [
     "DictListFile",
     "AdaptiveLock",
-    "loadpicklelzma",
-    "dumppicklelzma",
-    "make_cachefilepath",
-    "cacheobj",
-    "uncacheobj",
 ]
43 changes: 23 additions & 20 deletions halfpipe/io/file/pickle.py
@@ -2,57 +2,60 @@
 # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
 # vi: set ft=python sts=4 ts=4 sw=4 et:

+from typing import Optional
+
 import lzma
 import pickle
+from uuid import UUID
 from traits.trait_errors import TraitError
 from pathlib import Path

 from ...utils import logger


-def loadpicklelzma(filepath):
+def load_pickle_lzma(filepath):
     try:
         with lzma.open(filepath, "rb") as fptr:
             return pickle.load(fptr)
     except (lzma.LZMAError, TraitError, EOFError) as e:
         logger.error(f'Error while reading "{filepath}"', exc_info=e)


-def dumppicklelzma(filepath, obj):
+def dump_pickle_lzma(filepath, obj):
     try:
         with lzma.open(filepath, "wb") as fptr:
             pickle.dump(obj, fptr)
     except lzma.LZMAError as e:
         logger.error(f'Error while writing "{filepath}"', exc_info=e)


-def uncacheobj(workdir, typestr, uuid, typedisplaystr=None):
-    if typedisplaystr is None:
-        typedisplaystr = typestr
-    path = Path(workdir) / make_cachefilepath(typestr, uuid)
+def _make_cache_file_path(type_str: str, uuid: Optional[UUID]):
+    if uuid is not None:
+        uuidstr = str(uuid)[:8]
+        path = f"{type_str}.{uuidstr}.pickle.xz"
+    else:
+        path = f"{type_str}.pickle.xz"
+    return path
+
+
+def uncache_obj(workdir, type_str: str, uuid: UUID, display_str: str = None):
+    if display_str is None:
+        display_str = type_str
+    path = Path(workdir) / _make_cache_file_path(type_str, uuid)
     if path.exists():
-        obj = loadpicklelzma(path)
+        obj = load_pickle_lzma(path)
         if uuid is not None and hasattr(obj, "uuid"):
             objuuid = getattr(obj, "uuid")
             if objuuid is None or objuuid != uuid:
                 return
-        logger.info(f"Cached {typedisplaystr} from {path}")
+        logger.info(f"Cached {display_str} from {path}")
         return obj


-def make_cachefilepath(typestr, uuid):
-    if uuid is not None:
-        uuidstr = str(uuid)[:8]
-        path = f"{typestr}.{uuidstr}.pickle.xz"
-    else:
-        path = f"{typestr}.pickle.xz"
-    return path
-
-
-def cacheobj(workdir, typestr, obj, uuid=None):
+def cache_obj(workdir, typestr, obj, uuid=None):
     if uuid is None:
         uuid = getattr(obj, "uuid", None)
-    path = Path(workdir) / make_cachefilepath(typestr, uuid)
+    path = Path(workdir) / _make_cache_file_path(typestr, uuid)
     if path.exists():
         logger.warning(f"Overwrite {path}")
-    dumppicklelzma(path, obj)
+    dump_pickle_lzma(path, obj)
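
For orientation, a minimal usage sketch of the renamed cache helpers above, assuming the HALFpipe package with this change is installed; the work directory, the Graphs placeholder class, and the example file name are illustrative assumptions, not part of this commit:

from pathlib import Path
from uuid import uuid4

from halfpipe.io.file.pickle import _make_cache_file_path, cache_obj, uncache_obj


class Graphs:
    # hypothetical stand-in for a cacheable object; the helpers only need it to
    # pickle cleanly and, for validation, to carry a matching `uuid` attribute
    def __init__(self, uuid):
        self.uuid = uuid


workdir = Path("/tmp/halfpipe-work")  # assumed scratch work directory
workdir.mkdir(parents=True, exist_ok=True)
uuid = uuid4()

# the cache file name combines the type string and the first 8 uuid characters,
# for example "graphs.1a2b3c4d.pickle.xz"
print(_make_cache_file_path("graphs", uuid))

cache_obj(workdir, "graphs", Graphs(uuid))  # writes the .pickle.xz file into workdir
graphs = uncache_obj(workdir, "graphs", uuid, display_str="graphs")  # reloads it, or None on a uuid mismatch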
6 changes: 3 additions & 3 deletions halfpipe/workflow/base.py
@@ -18,7 +18,7 @@
 from .convert import convert_all
 from .constants import constants
 from ..io.index import Database, BidsDatabase
-from ..io.file import cacheobj, uncacheobj
+from ..io.file.pickle import cache_obj, uncache_obj
 from ..model.spec import loadspec
 from ..utils import logger, deepcopyfactory
 from .. import __version__
@@ -46,7 +46,7 @@ def init_workflow(workdir):
     # uuid depends on the spec file, the files found and the version of the program
     uuid = uuid5(spec.uuid, database.sha1 + __version__)

-    workflow = uncacheobj(workdir, ".workflow", uuid)
+    workflow = uncache_obj(workdir, ".workflow", uuid, display_str="workflow")
     if workflow is not None:
         return workflow

@@ -139,6 +139,6 @@ def init_workflow(workdir):
         node.run_without_submitting = False  # run all nodes in multiproc

     logger.info(f"Finished workflow {uuidstr}")
-    cacheobj(workdir, ".workflow", workflow)
+    cache_obj(workdir, ".workflow", workflow)

     return workflow
13 changes: 7 additions & 6 deletions halfpipe/workflow/execgraph.py
@@ -27,7 +27,8 @@
 from ..utils import resolve
 from ..utils.format import format_like_bids
 from ..utils.multiprocessing import Pool
-from ..io import DictListFile, cacheobj, uncacheobj
+from ..io.file.dictlistfile import DictListFile
+from ..io.file.pickle import cache_obj, uncache_obj
 from ..resource import get as getresource
 from .constants import constants

@@ -209,14 +210,14 @@ def prepare_graph(workflow, item):


 def init_flat_graph(workflow, workdir) -> nx.DiGraph:
-    flat_graph = uncacheobj(workdir, ".flat_graph", workflow.uuid)
+    flat_graph = uncache_obj(workdir, ".flat_graph", workflow.uuid, display_str="flat graph")
     if flat_graph is not None:
         return flat_graph

     workflow._generate_flatgraph()
     flat_graph = workflow._graph

-    cacheobj(workdir, ".flat_graph", flat_graph, uuid=workflow.uuid)
+    cache_obj(workdir, ".flat_graph", flat_graph, uuid=workflow.uuid)
     return flat_graph


@@ -249,7 +250,7 @@ def init_execgraph(

     # create or load execgraph

-    graphs: Optional[OrderedDictT[str, IdentifiableDiGraph]] = uncacheobj(workdir, "graphs", uuid)
+    graphs: Optional[OrderedDictT[str, IdentifiableDiGraph]] = uncache_obj(workdir, "graphs", uuid)
     if graphs is not None:
         return graphs

@@ -281,7 +282,7 @@

     with Pool() as pool:
         graphs = OrderedDict(
-            pool.imap(partial(prepare_graph, workflow), graphs.items())
+            pool.map(partial(prepare_graph, workflow), graphs.items())
         )

     logger.info("Update input source at chunk boundaries")
@@ -292,6 +293,6 @@
                 node.input_source.update(input_source_dict[node])

     logger.info(f'Created graphs for workflow "{uuidstr}"')
-    cacheobj(workdir, "graphs", graphs, uuid=uuid)
+    cache_obj(workdir, "graphs", graphs, uuid=uuid)

     return graphs
