
[BUG] After simplifying the model swin_tiny_patch4_window7_224 (created with timm), onnxruntime fails to load it #307

Open
lichun-wang opened this issue Sep 4, 2023 · 1 comment

@lichun-wang

Describe the bug
I create 'swin_tiny_patch4_window7_224' with timm and export it to swin_tiny_patch4_window7_224.onnx with torch.onnx.export. Running the exported model with onnxruntime works fine.

But after I simplify the exported model with onnx_simplifier, onnxruntime can no longer load it and fails with this error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from swin_tiny_patch4_window7_224.onnx failed:Node (/layers/layers.0/blocks/blocks.0/Reshape_6) Op (Reshape) [ShapeInferenceError] Dimension could not be inferred: incompatible shapes
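
For readers unfamiliar with this message: it is raised during shape inference on a Reshape node. Below is a minimal pure-Python sketch of the Reshape rule as the ONNX operator spec describes it (with allowzero=0), which may help when reading the error. `infer_reshape` is a hypothetical helper written for illustration, not onnx or onnxruntime code, and the shapes are made up rather than taken from the exported graph. It does not identify the cause of the failure in this issue.

```python
from math import prod


def infer_reshape(input_shape, target_shape):
    """Resolve a Reshape target shape (ONNX semantics, allowzero=0).

    A 0 copies the input dimension at the same index, and a single -1
    is inferred from the remaining element count.
    """
    if list(target_shape).count(-1) > 1:
        raise ValueError("at most one dimension may be -1")

    # Replace each 0 with the input dimension at the same index.
    resolved = []
    for i, d in enumerate(target_shape):
        if d == 0:
            if i >= len(input_shape):
                raise ValueError("0 refers to a missing input dimension")
            d = input_shape[i]
        resolved.append(d)

    total = prod(input_shape)
    known = prod(d for d in resolved if d != -1)

    if -1 in resolved:
        # The -1 can only be inferred if the known dims divide the total.
        if known == 0 or total % known != 0:
            raise ValueError("dimension could not be inferred: incompatible shapes")
        resolved[resolved.index(-1)] = total // known
    elif known != total:
        raise ValueError("element counts differ: incompatible shapes")
    return resolved


# Hypothetical shapes, chosen only to show one passing and one failing case.
print(infer_reshape([2, 56, 56, 96], [-1, 7, 7, 96]))
try:
    infer_reshape([2, 56, 56, 96], [1, 5, -1, 96])
except ValueError as err:
    print("failed:", err)
```

The second call fails because the known target dimensions do not divide the input's element count, so no value for -1 can satisfy the reshape. If shape values baked into a graph no longer agree with the actual input shape, a check of this kind is where the mismatch would surface.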

My Code

import timm
import onnxruntime as ort
import torch

model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=1000).eval()

# input

dummy_input = torch.randn(*(1, 3, 224, 224), device='cpu')

onnx_path = 'swin_tiny_patch4_window7_224.onnx'

torch.onnx.export(model,
                dummy_input,
                onnx_path,
                verbose=False, 
                opset_version=17,
                do_constant_folding=True,  
                keep_initializers_as_inputs=True, 
                input_names=["input"],      
                output_names=["output"],  
                dynamic_axes={"input":{0:"batch_size"},"output":{0:"batch_size"}}
                )

import onnx
from onnxsim import onnx_simplifier
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
model_simp, check = onnx_simplifier.simplify(onnx_model, check_n = 0)
onnx.save(model_simp, onnx_path)
print(f"simplify over : {onnx_path}  ")


ort_sess = ort.InferenceSession(onnx_path)
outputs = ort_sess.run(None, {'input': dummy_input.numpy().astype('float32')})
@Collonville

I have the same error 😢
