raise TypeError(f"Please convert {rhs} with const first") #36

Open
Wanger-SJTU opened this issue Jun 9, 2023 · 6 comments

Comments

@Wanger-SJTU

While loading the model, I get:

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
Traceback (most recent call last):
  File "/srv/workspace/anaconda3/envs/web_sd/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "/srv/workspace/anaconda3/envs/web_sd/lib/python3.9/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/srv/workspace/framework/tvm/python/tvm/relax/frontend/torch/dynamo.py", line 151, in _capture
    mod_ = from_fx(
  File "/srv/workspace/framework/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1421, in from_fx
    return TorchFXImporter().from_fx(
  File "/srv/workspace/framework/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1308, in from_fx
    self.env[node] = self.convert_map[func_name](node)
  File "/srv/workspace/framework/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 180, in _add
    return lhs + rhs
  File "/srv/workspace/framework/tvm/python/tvm/relax/expr.py", line 155, in __add__
    return _binary_op_helper(self, other, _op_ffi_api.add)  # type: ignore
  File "/srv/workspace/framework/tvm/python/tvm/relax/expr.py", line 104, in _binary_op_helper
    raise TypeError(f"Please convert {rhs} with `const` first")
TypeError: Please convert 1 with `const` first
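The last frames show where this comes from: `TorchFXImporter._add` falls through to `return lhs + rhs`, and `relax.Expr.__add__` passes the operands to `_binary_op_helper`, which rejects a bare Python scalar such as the `1` here. The pure-Python sketch below mimics that guard with hypothetical `Expr` and `const` stand-ins (not the real TVM classes) to show why `expr + 1` fails while `expr + const(1, ...)` goes through.

```python
class Expr:
    """Hypothetical stand-in for relax.Expr (not the real TVM class)."""

    def __init__(self, text: str, dtype: str) -> None:
        self.text = text
        self.dtype = dtype

    def __add__(self, other):
        return _binary_op_helper(self, other, "add")

    def __repr__(self) -> str:
        return self.text


def const(value, dtype: str) -> Expr:
    """Hypothetical stand-in for relax.const: wrap a scalar with a dtype."""
    return Expr(f"const({value!r}, {dtype!r})", dtype)


def _binary_op_helper(lhs: Expr, rhs, op: str) -> Expr:
    # Mirrors the guard seen in the traceback: the right-hand side must
    # already be an expression, never a raw Python int/float.
    if not isinstance(rhs, Expr):
        raise TypeError(f"Please convert {rhs} with `const` first")
    return Expr(f"{op}({lhs!r}, {rhs!r})", lhs.dtype)


x = Expr("x", "int64")

try:
    x + 1  # raw Python int: rejected, like the error in this issue
except TypeError as err:
    print(err)  # Please convert 1 with `const` first

print(x + const(1, "int64"))  # add(x, const(1, 'int64'))
```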
@yohuna77777

I have the same problem.

@nineis7

nineis7 commented Jul 7, 2023

You can use diffusers version 0.15.0: pip install diffusers==0.15.0

@haili-tian

haili-tian commented Aug 16, 2023

> You can use diffusers 0.15.0 version. pip install diffusers==0.15.0

I replaced diffusers with version 0.15.0, but the issue still exists.
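When a version pin appears to have no effect, one thing worth ruling out is that the interpreter running the notebook is not the environment pip installed into. A small sketch using only the standard library (`importlib.metadata`, Python 3.8+) reports which `diffusers` the current interpreter actually sees; the helper name `installed_version` is hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    expected = "0.15.0"  # the pin suggested earlier in this thread
    found = installed_version("diffusers")
    print(f"interpreter: {sys.executable}")
    if found is None:
        print("diffusers is not installed for this interpreter")
    elif found != expected:
        print(f"diffusers {found} is active, not {expected}")
    else:
        print(f"diffusers {expected} is active")
```

Running the same check from the notebook kernel and from the shell where `pip install` was run should print the same interpreter path.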

@haili-tian

This can be fixed with the following patch:

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index fde31af60..392ff2b39 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -177,6 +177,23 @@ class TorchFXImporter:
             return self._call_binary_op(
                 relax.op.add, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs
             )
+        elif isinstance(lhs, int):
+            return self._call_binary_op(
+                relax.op.add, relax.const(lhs, dtype="int64"), rhs
+            )
+        elif isinstance(rhs, int):
+            return self._call_binary_op(
+                relax.op.add, lhs, relax.const(rhs, dtype="int64")
+            )
+        elif isinstance(lhs, float):
+            return self._call_binary_op(
+                relax.op.add, relax.const(lhs, dtype="float32"), rhs
+            )
+        elif isinstance(rhs, float):
+            return self._call_binary_op(
+                relax.op.add, lhs, relax.const(rhs, dtype="float32")
+            )
+
         return lhs + rhs
 
     def _max(self, node: fx.node.Node) -> relax.Expr:
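The patch above works by wrapping the Python scalar in a typed constant before calling `relax.op.add`, with fixed dtypes of `int64` for ints and `float32` for floats. A minimal pure-Python sketch of that promotion step follows; `Expr`, `const`, `promote_scalar`, and `add` are hypothetical stand-ins, not TVM APIs. One assumption differs from the patch: when the other operand is an expression, the sketch reuses its dtype, the way the existing `Constant` branch reuses `struct_info.dtype`, and only falls back to `int64`/`float32` when there is nothing to copy.

```python
class Expr:
    """Hypothetical stand-in for a relax expression with a known dtype."""

    def __init__(self, text: str, dtype: str) -> None:
        self.text = text
        self.dtype = dtype

    def __repr__(self) -> str:
        return self.text


def const(value, dtype: str) -> Expr:
    """Hypothetical stand-in for relax.const: a scalar with an explicit dtype."""
    return Expr(f"const({value!r}, {dtype!r})", dtype)


def promote_scalar(value, like: object = None) -> Expr:
    """Pass expressions through; wrap Python ints/floats as typed constants."""
    if isinstance(value, Expr):
        return value
    if isinstance(value, (int, float)):  # bool subclasses int, so it lands here too
        # Copy the other operand's dtype when there is one (an assumption of
        # this sketch); otherwise use the patch's int64/float32 choices.
        if isinstance(like, Expr):
            return const(value, like.dtype)
        return const(value, "float32" if isinstance(value, float) else "int64")
    raise TypeError(f"unsupported operand: {value!r}")


def add(lhs, rhs):
    """Sketch of an `_add` converter that promotes a scalar on either side."""
    if not isinstance(lhs, Expr) and not isinstance(rhs, Expr):
        return lhs + rhs  # two Python scalars: plain Python arithmetic
    # The right-hand tuple is evaluated before assignment, so each call
    # sees the original other operand.
    lhs, rhs = promote_scalar(lhs, like=rhs), promote_scalar(rhs, like=lhs)
    return Expr(f"add({lhs!r}, {rhs!r})", lhs.dtype)


x16 = Expr("x", "float16")
print(add(x16, 1))        # add(x, const(1, 'float16'))
print(add(2.5, x16))      # add(const(2.5, 'float16'), x)
print(promote_scalar(1))  # const(1, 'int64')
```

Reusing the tensor operand's dtype is a design choice of this sketch, not what the patch does; with a float16 model, for instance, the patch's fixed `float32` constant would not match the dtype of the tensor it is added to. Also note that `isinstance(x, int)` is true for `bool` values, so `True`/`False` take the integer path in both the patch and this sketch.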

@BillyGun27

This patch fixes the clip_to_text_embeddings(pipe) function, but running vae_to_image(pipe) then hits another error:

AssertionError                            Traceback (most recent call last)
File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:670, in OutputGraph.call_user_compiler(self, gm)
    669 else:
--> 670     compiled_fn = compiler_fn(gm, self.fake_example_inputs())
    671 _step_logger()(logging.INFO, f"done compiler function {name}")

File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:1055, in wrap_backend_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
   1054 else:
-> 1055     compiled_gm = compiler_fn(gm, example_inputs)
   1057 return compiled_gm

File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/dynamo.py:161, in dynamo_capture_subgraphs.<locals>._capture(graph_module, example_inputs)
    160 input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs]
--> 161 mod_ = from_fx(
    162     graph_module,
    163     input_info,
    164     keep_params_as_input=keep_params_as_input,
    165     unwrap_unit_return_tuple=True,
    166 )
    167 new_name = f"subgraph_{len(mod.get_global_vars())}"

File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/fx_translator.py:1492, in from_fx(model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple)
   1404 """Convert a PyTorch FX GraphModule to a Relax program
   1405 
   1406 Parameters
   (...)
   1490 check the placeholder rows in the beginning of the tabular.
   1491 """
-> 1492 return TorchFXImporter().from_fx(
   1493     model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
   1494 )

File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/fx_translator.py:1377, in TorchFXImporter.from_fx(self, model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple)
   1375 func_name = node.name.rstrip("0123456789_")
   1376 assert (
-> 1377     func_name in self.convert_map
   1378 ), f"Unsupported function type {func_name}"
   1379 self.env[node] = self.convert_map[func_name](node)

AssertionError: Unsupported function type conv2d

The above exception was the direct cause of the following exception:

BackendCompilerFailed                     Traceback (most recent call last)
Cell In[13], line 1
----> 1 vae = vae_to_image(pipe)

Cell In[10], line 22, in vae_to_image(pipe)
     19 vae_to_image = VAEModelWrapper(vae)
     21 z = torch.rand((1, 4, 64, 64), dtype=torch.float32)
---> 22 mod = dynamo_capture_subgraphs(
     23     vae_to_image.forward,
     24     z,
     25     keep_params_as_input=True,
     26 )
     27 assert len(mod.functions) == 1
     29 return tvm.IRModule({"vae": mod["subgraph_0"]})

File /opt/anaconda/envs/venv-mlc/lib/python3.10/site-packages/tvm/relax/frontend/torch/dynamo.py:175, in dynamo_capture_subgraphs(model, *params, **kwargs)
    172 compiled_model = torch.compile(model, backend=_capture)
    174 with torch.no_grad():
--> 175     compiled_model(*params, **kwargs)
    177 return mod
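The second error has a different cause, so the `_add` patch cannot help with it. Per the frame at `fx_translator.py:1375-1379`, the importer builds a lookup key by stripping trailing digits and underscores from the FX node's name (`node.name.rstrip("0123456789_")`) and asserts that the key exists in `convert_map`; in the build used here, the VAE's `conv2d` node has no entry. The sketch below reproduces that lookup with a hypothetical two-entry table (the real `convert_map` is much larger) and shows that registering a converter under the stripped key is what makes the lookup succeed.

```python
from typing import Callable, Dict


def lookup_key(node_name: str) -> str:
    """Same rule as in the traceback: strip trailing digits/underscores."""
    return node_name.rstrip("0123456789_")


# Hypothetical converter table with two entries. The real convert_map maps
# keys to callables that take the FX node and return a Relax expression.
CONVERT_MAP: Dict[str, Callable[[str], str]] = {
    "add": lambda name: f"relax.op.add for {name}",
    "max": lambda name: f"relax.op.max for {name}",
}


def convert(node_name: str) -> str:
    key = lookup_key(node_name)
    if key not in CONVERT_MAP:
        # The importer uses a bare `assert`; raising explicitly keeps the
        # check active even under `python -O`.
        raise NotImplementedError(f"Unsupported function type {key}")
    return CONVERT_MAP[key](node_name)


print(convert("add_3"))  # relax.op.add for add_3

try:
    convert("conv2d_1")  # key "conv2d" is not registered yet
except NotImplementedError as err:
    print(err)  # Unsupported function type conv2d

# Registering a converter under the stripped key makes the same node convert.
CONVERT_MAP["conv2d"] = lambda name: f"relax.op.nn.conv2d for {name}"
print(convert("conv2d_1"))  # relax.op.nn.conv2d for conv2d_1
```

In the real importer, that would mean adding a `conv2d` converter to `TorchFXImporter`'s `convert_map`, or moving to a TVM version that already has one; which versions do is worth checking rather than assuming.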

@jinqiua

jinqiua commented Sep 8, 2023

I have the same problem.
