WIP: fix update attr #10278

Open: wants to merge 2 commits into base: master

Pang-GJ
Contributor

@Pang-GJ Pang-GJ commented May 22, 2023

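Fixes issue #10156.
The issue occurs because the height_scale and width_scale of the upsample operator in the network are incorrect.
The cause is that attrs are not updated when inference reuses a shared graph, so the scale in effect is still the one from the first inference.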
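
To make the failure mode concrete, here is a minimal toy sketch in Python. It is not OneFlow code: `ToyOp`, `trace_ops`, and `build_from_shared` are hypothetical names that only model the idea of reusing a shared graph's ops with or without refreshing their attrs.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOp:
    """Hypothetical stand-in for an op in a compiled graph."""
    name: str
    attrs: dict = field(default_factory=dict)


def trace_ops(scale):
    # Pretend to trace the network for one input size.
    # The upsample scale depends on that input size.
    return [ToyOp("upsample", {"height_scale": scale, "width_scale": scale})]


def build_from_shared(shared_ops, new_ops, update_attrs):
    # Reuse the structure of the shared graph. Attrs are copied from the
    # newly traced ops only when update_attrs is True; otherwise the attrs
    # of the first compilation are kept, which models the bug.
    built = []
    for shared, new in zip(shared_ops, new_ops):
        attrs = dict(new.attrs) if update_attrs else dict(shared.attrs)
        built.append(ToyOp(shared.name, attrs))
    return built


first = trace_ops(2.0)    # first inference
second = trace_ops(1.5)   # later inference with a different input size

stale = build_from_shared(first, second, update_attrs=False)
fresh = build_from_shared(first, second, update_attrs=True)
```

In this toy model, `stale` keeps the scale of the first inference while `fresh` carries the scale traced for the new input, which is the behavior this PR aims for.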

// UPDATE_ATTR_MUTABLE(at_stride);
// UPDATE_ATTR_MUTABLE(at_list_stride);
// UPDATE_ATTR_MUTABLE(at_device);
// UPDATE_ATTR_MUTABLE(at_complex_double);
Contributor Author

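Some of the attr updates commented out here cause problems, @strint. I am still working out which update is responsible. So far I have hit two cases: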

  1. If every line is uncommented, i.e. all attrs are updated, this error occurs:
Traceback (most recent call last):
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data/home/pangguojian/onediff/issue/reproduce_bug.py", line 16, in <module>
    img = pipe("A beautiful house", width=item, height=item,)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/autograd/autograd_mode.py", line 154, in wrapper
    return func(*args, **kwargs)
  File "/data/home/pangguojian/.local/lib/python3.9/site-packages/onediff/pipeline_stable_diffusion_oneflow.py", line 703, in __call__
    noise_pred = unet_graph(latent_model_input, t, text_embeddings)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 269, in __call__
    return self._dynamic_input_graph_cache(*args, **kwargs)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/cache.py", line 115, in __call__
    return graph(*args, **kwargs)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 272, in __call__
    self._compile(*args, **kwargs)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 842, in _compile
    return self._compile_from_shared(*args, **kwargs)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 964, in _compile_from_shared
    self._c_nn_graph.build_with_new_input_from_shared_graph(
oneflow._oneflow_internal.exception.RuntimeError: Error: Check failed: ((320,3,3,4) == (320,81,3,3))
  File "oneflow/core/graph/op_graph.h", line 102, in OpGraph
    Init(job)
  File "oneflow/core/graph/op_graph.cpp", line 179, in Init
    InferLogicalBlobDesc(job)
  File "oneflow/core/graph/op_graph.cpp", line 356, in InferLogicalBlobDesc
    TopoForEachNodeWithErrorCaptured([&](OpNode* op_node ... InferLogicalOutBlobDescsIf()); return Maybe<void>::Ok(); })
  File "oneflow/core/graph/graph.h", line 657, in TopoForEachNodeWithErrorCaptured
    Handler(cur_node)
  File "oneflow/core/graph/op_graph.cpp", line 392, in operator()
    op_node->mut_op()->InferLogicalOutBlobDescsIf()
  File "oneflow/core/operator/operator.cpp", line 329, in InferLogicalOutBlobDescsIf
    InferLogicalOutBlobDescs(BlobDesc4BnInOp, *JUST(GetOpParallelDesc()))
  File "oneflow/core/operator/user_op.cpp", line 773, in InferLogicalOutBlobDescs
    val_->logical_tensor_desc_infer_fn(&infer_ctx)
  File "oneflow/user/ops/conv_op.cpp", line 79, in InferTensorDesc4Conv
    CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape))
Error Type: oneflow.ErrorProto.check_failed_error
  File "oneflow/core/graph/op_graph.h", line 102, in operator()
Error Type: oneflow.ErrorProto.runtime_error
  2. With the following set of updates:
        UPDATE_ATTR_SET(at_int32);
        UPDATE_ATTR_SET(at_int64);
        UPDATE_ATTR_SET(at_bool);
        UPDATE_ATTR_SET(at_float);
        UPDATE_ATTR_SET(at_double);
        UPDATE_ATTR_SET(at_string);
        // UPDATE_ATTR_SET(at_data_type);
        // UPDATE_ATTR_SET(at_memory_format);

        UPDATE_ATTR_MUTABLE(at_shape);
        UPDATE_ATTR_MUTABLE(at_list_int32);
        UPDATE_ATTR_MUTABLE(at_list_int64);
        UPDATE_ATTR_MUTABLE(at_list_float);
        UPDATE_ATTR_MUTABLE(at_list_data_type);
        UPDATE_ATTR_MUTABLE(at_list_shape);
        UPDATE_ATTR_MUTABLE(at_list_string);
        UPDATE_ATTR_MUTABLE(at_stride);
        UPDATE_ATTR_MUTABLE(at_list_stride);
        UPDATE_ATTR_MUTABLE(at_device);
        UPDATE_ATTR_MUTABLE(at_complex_double);

an OOM occurs:

terminate called after throwing an instance of 'oneflow::RuntimeException'
  what():  Error: CUDA out of memory. Tried to allocate 114.0 MB
Error message from /data/home/pangguojian/oneflow-master/oneflow/core/vm/op_call_instruction_policy.cpp:278
        OpCallInstructionUtil::Compute(this, instruction->mut_stream(), true, false): empty:OpCall:s_compute

  File "oneflow/core/vm/op_call_instruction_policy.cpp", line 278, in Compute
    OpCallInstructionUtil::Compute(this, instruction->mut_stream(), true, false)
  File "oneflow/core/vm/op_call_instruction_policy.cpp", line 63, in Compute
    AllocateOutputBlobsMemory(op_call_instruction_policy, allocator, vm_stream)
  File "oneflow/core/vm/op_call_instruction_policy.cpp", line 115, in AllocateOutputBlobsMemory
    blob_object->TryAllocateBlobBodyMemory(allocator)
  File "oneflow/core/eager/eager_blob_object.cpp", line 105, in TryAllocateBlobBodyMemory
    allocator->Allocate(&dptr, required_body_bytes)
  File "oneflow/core/vm/bin_allocator.h", line 392, in Allocate
    AllocateBlockToExtendTotalMem(aligned_size)
  File "oneflow/core/vm/bin_allocator.h", line 305, in AllocateBlockToExtendTotalMem
    backend_->Allocate(&mem_ptr, final_allocate_bytes)
Error Type: oneflow.ErrorProto.out_of_memory_error
  File "oneflow/core/vm/op_call_instruction_policy.cpp", line 278, in operator()
Error Type: oneflow.ErrorProto.runtime_error
Related Python stack trace:
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/threading.py", line 316, in wait
    gotit = waiter.acquire(True, timeout)
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/threading.py", line 581, in wait
    signaled = self._cond.wait(timeout)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 60, in _compile_new

  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 840, in _compile
    return self._compile_new(*args, **kwargs)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/cache.py", line 120, in _compile
    return graph._compile(*args, **kwargs)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/nn/graph/graph.py", line 836, in _compile
    return self._dynamic_input_graph_cache._compile(*args, **kwargs)
  File "/data/home/pangguojian/.local/lib/python3.9/site-packages/onediff/pipeline_stable_diffusion_oneflow.py", line 684, in __call__
    unet_graph._compile(latent_model_input, t, text_embeddings)
  File "/data/home/pangguojian/oneflow-master/python/oneflow/autograd/autograd_mode.py", line 154, in wrapper
    return func(*args, **kwargs)
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/site-packages/tqdm/contrib/concurrent.py", line 16, in _executor_map
    try:
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/site-packages/tqdm/contrib/concurrent.py", line 94, in thread_map
    return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
  File "/data/home/pangguojian/.local/lib/python3.9/site-packages/huggingface_hub/_snapshot_download.py", line 239, in snapshot_download
    thread_map(
  File "/data/home/pangguojian/.local/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py", line 120, in _inner_fn
    return fn(*args, **kwargs)
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/site-packages/diffusers/pipelines/pipeline_utils.py", line 692, in from_pretrained
    cached_folder = snapshot_download(
  File "/data/home/pangguojian/miniconda3/envs/oneflow-torch/lib/python3.9/runpy.py", line 8, in _run_module_as_main
    """

Stack trace (most recent call last) in thread 2621583:
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7faeb1, in
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7faecd, in
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7faf2e, in
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7faffb, in
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7fb0e0, in
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7eabce, in
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7e4bb9, in vm::ThreadCtx::TryReceiveAndRun()
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7da06c, in vm::StreamPolicy::RunIf(vm::Instruction*) const
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d779cf2, in vm::EpStreamPolicyBase::Run(vm::Instruction*) const
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d77cc55, in vm::Instruction::Compute()
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d77d453, in vm::InstructionPolicy::ComputeIf(vm::Instruction*)
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7853a2, in vm::OpCallInstructionPolicy::Compute(vm::Instruction*)
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306d7851ba, in
   Object "/data/home/pangguojian/oneflow-master/python/oneflow/_oneflow_internal.cpython-39-x86_64-linux-gnu.so", at 0x7f30f52abd62, in details::Throw::operator=(Error&&)
   Object "/data/home/pangguojian/oneflow-master/build/liboneflow.so", at 0x7f306587aded, in ThrowError(std::shared_ptr<StackedError> const&)

Aborted (Signal sent by tkill() 2619772 1016)
Aborted (core dumped)

Contributor Author

Updating only at_double and at_shape works without problems.

Contributor

OK. Please open a PR first that adds only the at_double update.

We can submit the updates for the other attrs once we understand the cause.
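
A rough sketch of that narrower plan, again as a toy Python model rather than the actual C++ change: only attrs whose type tag is on an allow-list are refreshed from the newly traced op. The function `refresh_attrs` and the attr names `height_scale` and `strides` are illustrative assumptions; the type tags `at_double` and `at_list_int32` are taken from the macros quoted earlier in this thread.

```python
def refresh_attrs(shared_attrs, new_attrs, allowed_types):
    """Return the shared attrs with only allow-listed types refreshed.

    Both attr dicts map an attr name to a (type_tag, value) pair.
    This is a hypothetical model, not the OneFlow implementation.
    """
    merged = dict(shared_attrs)
    for name, (type_tag, value) in new_attrs.items():
        if type_tag in allowed_types:
            merged[name] = (type_tag, value)
    return merged


shared = {
    "height_scale": ("at_double", 2.0),
    "strides": ("at_list_int32", [1, 1]),
}
new = {
    "height_scale": ("at_double", 1.5),
    "strides": ("at_list_int32", [2, 2]),
}

# Refresh only at_double attrs; everything else keeps the shared value.
merged = refresh_attrs(shared, new, allowed_types={"at_double"})
```

Keeping the allow-list small limits the change to the attr type known to fix the upsample scale, while the other types stay untouched until their failures are understood.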

@github-actions
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/10278/

@github-actions
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3090 

❌ OneFlow resnet50 time: 43.1ms (= 4314.1ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.5ms (= 5754.7ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.33 (= 57.5ms / 43.1ms)

OneFlow resnet50 time: 25.8ms (= 2580.4ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.4ms (= 3740.2ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.45 (= 37.4ms / 25.8ms)

OneFlow resnet50 time: 18.3ms (= 3652.0ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.5ms (= 7092.9ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.94 (= 35.5ms / 18.3ms)

OneFlow resnet50 time: 16.8ms (= 3356.7ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 33.5ms (= 6693.2ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.99 (= 33.5ms / 16.8ms)

OneFlow resnet50 time: 16.1ms (= 3222.9ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 28.5ms (= 5701.8ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.77 (= 28.5ms / 16.1ms)

OneFlow swin dataloader time: 0.201s (= 40.283s / 200, num_workers=1)
PyTorch swin dataloader time: 0.128s (= 25.594s / 200, num_workers=1)
Relative speed: 0.635 (= 0.128s / 0.201s)

OneFlow swin dataloader time: 0.056s (= 11.232s / 200, num_workers=4)
PyTorch swin dataloader time: 0.032s (= 6.498s / 200, num_workers=4)
Relative speed: 0.579 (= 0.032s / 0.056s)

OneFlow swin dataloader time: 0.033s (= 6.630s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.339s / 200, num_workers=8)
Relative speed: 0.504 (= 0.017s / 0.033s)

❌ OneFlow resnet50 time: 48.5ms (= 4848.9ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 64.4ms (= 6442.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.33 (= 64.4ms / 48.5ms)

OneFlow resnet50 time: 37.0ms (= 3698.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 45.5ms (= 4547.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.23 (= 45.5ms / 37.0ms)

OneFlow resnet50 time: 28.2ms (= 5634.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 39.1ms (= 7815.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.39 (= 39.1ms / 28.2ms)

OneFlow resnet50 time: 25.3ms (= 5066.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.5ms (= 7702.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.52 (= 38.5ms / 25.3ms)

OneFlow resnet50 time: 25.3ms (= 5052.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 36.2ms (= 7230.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.43 (= 36.2ms / 25.3ms)

@github-actions
Contributor

Speed stats:
GPU Name: NVIDIA GeForce RTX 3090 

❌ OneFlow resnet50 time: 43.2ms (= 4317.7ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 57.7ms (= 5769.7ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.34 (= 57.7ms / 43.2ms)

OneFlow resnet50 time: 25.8ms (= 2582.0ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 37.7ms (= 3774.0ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.46 (= 37.7ms / 25.8ms)

OneFlow resnet50 time: 18.5ms (= 3695.3ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 35.8ms (= 7157.9ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.94 (= 35.8ms / 18.5ms)

OneFlow resnet50 time: 18.2ms (= 3643.5ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 31.8ms (= 6363.3ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.75 (= 31.8ms / 18.2ms)

OneFlow resnet50 time: 16.7ms (= 3344.6ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 31.7ms (= 6342.9ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.90 (= 31.7ms / 16.7ms)

OneFlow swin dataloader time: 0.200s (= 40.099s / 200, num_workers=1)
PyTorch swin dataloader time: 0.129s (= 25.704s / 200, num_workers=1)
Relative speed: 0.641 (= 0.129s / 0.200s)

OneFlow swin dataloader time: 0.056s (= 11.103s / 200, num_workers=4)
PyTorch swin dataloader time: 0.033s (= 6.623s / 200, num_workers=4)
Relative speed: 0.597 (= 0.033s / 0.056s)

OneFlow swin dataloader time: 0.030s (= 6.037s / 200, num_workers=8)
PyTorch swin dataloader time: 0.017s (= 3.345s / 200, num_workers=8)
Relative speed: 0.554 (= 0.017s / 0.030s)

❌ OneFlow resnet50 time: 48.5ms (= 4852.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 65.8ms (= 6583.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.36 (= 65.8ms / 48.5ms)

OneFlow resnet50 time: 36.6ms (= 3656.1ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 46.6ms (= 4663.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.28 (= 46.6ms / 36.6ms)

OneFlow resnet50 time: 28.7ms (= 5731.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 38.8ms (= 7757.0ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 38.8ms / 28.7ms)

OneFlow resnet50 time: 25.4ms (= 5088.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 39.1ms (= 7825.8ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.54 (= 39.1ms / 25.4ms)

OneFlow resnet50 time: 24.6ms (= 4921.9ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 36.2ms (= 7249.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.47 (= 36.2ms / 24.6ms)
