
LSTM Timeseries prediction example #1532

Draft · wants to merge 16 commits into main

Conversation

@NicoZweifel (Contributor) commented Mar 26, 2024

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Changes

  • Adds a timeseries forecasting example using the LSTM that was added in Feat/lstm #370, based on a partial dataset from Huggingface.
  • The dataset is limited to 10000 entries at the moment; training on the full dataset still seems buggy. I am not sure whether the normalization is off or whether there is a memory limitation or a bug in the SqliteDataset.

I have not narrowed it down yet, since the custom datasets in my other burn project work fine (in-memory, with data from alphavantage). I may need to spend some more time on it, but since it doesn't block my other goals and the example works with 10000 entries, I thought I could publish this as a draft for now.
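
For reference, the basic idea of framing a univariate series for LSTM forecasting, independent of the actual dataset types used in this PR, looks roughly like this (an illustrative plain-Rust sketch only, with hypothetical names):

// Illustrative sketch (not the code in this PR): frame a univariate series
// into (input window, next-value target) pairs for forecasting, limiting the
// series to its first 10000 entries as the example currently does.
fn make_windows(series: &[f64], window: usize, limit: usize) -> Vec<(Vec<f64>, f64)> {
    let series = &series[..series.len().min(limit)];
    series
        .windows(window + 1)
        .map(|w| (w[..window].to_vec(), w[window]))
        .collect()
}

fn main() {
    let series: Vec<f64> = (0..20).map(|i| i as f64).collect();
    let samples = make_windows(&series, 4, 10000);
    // Each sample is ([x_{t-3}, x_{t-2}, x_{t-1}, x_t], x_{t+1}).
    println!("{:?}", samples.first());
}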

Testing

cargo run --example lstm --features tch-cpu


codecov bot commented Mar 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.31%. Comparing base (7705fd9) to head (9dedb3f).
Report is 59 commits behind head on main.

Current head 9dedb3f differs from pull request most recent head 8a8b57f

Please upload reports for the commit 8a8b57f to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1532      +/-   ##
==========================================
- Coverage   86.38%   86.31%   -0.07%     
==========================================
  Files         693      683      -10     
  Lines       80473    78091    -2382     
==========================================
- Hits        69519    67408    -2111     
+ Misses      10954    10683     -271     


@nathanielsimard (Member) left a comment

Nice example, ping me when you think it's going to be ready for a review.

@wcshds (Contributor) commented Mar 28, 2024

@nathanielsimard Hello, I have always been interested in the LSTM implementation in burn. I still think the LSTM is buggy right now: if a linear layer is added after the LSTM, the parameters of the LSTM and of all layers before it are not updated during training. I've been stuck on this problem for a long time.

The example in this PR further confirms that the LSTM has problems. I added some code in training.rs to check the model's parameters before and after training:

// Record the parameters before training.
let pjr = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
model.input_layer.clone().save_file("./input-before.json", &pjr).unwrap();
model.lstm.clone().save_file("./lstm-before.json", &pjr).unwrap();
model.output_layer.clone().save_file("./output-before.json", &pjr).unwrap();

// ...... training runs in between ......

// Record the parameters after training.
model_trained.input_layer.clone().save_file("./input-after.json", &pjr).unwrap();
model_trained.lstm.clone().save_file("./lstm-after.json", &pjr).unwrap();
model_trained.output_layer.clone().save_file("./output-after.json", &pjr).unwrap();

After training, only the parameters of the output_layer had changed. That said, for the dataset in this example, a single linear layer might be enough to overfit.
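
A quicker check, if diffing the JSON files is inconvenient, is to compare a weight tensor directly (a sketch only; that input_layer is a Linear exposing weight as a Param tensor is an assumption about this example's model):

// Sketch: compare the input layer's weights before and after training in code.
// Assumes `input_layer` is a burn Linear with `weight: Param<Tensor<B, 2>>`.
let weight_before = model.input_layer.weight.val();

// ... learner.fit(...) produces `model_trained` ...

let weight_after = model_trained.input_layer.weight.val();
let delta = weight_after.sub(weight_before).abs().sum().into_scalar();
// If the LSTM and the layers before it never receive gradients, this stays at zero.
println!("input_layer weight delta: {:?}", delta);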

@NicoZweifel (Contributor, Author) commented Mar 28, 2024

> @nathanielsimard Hello, I have always been interested in the LSTM implementation in burn.

I was hoping I could spark the development of the LSTM implementation a bit with an example. I would love to use Burn for this purpose as well.

> After training, only the parameters of the output_layer had changed. That said, for the dataset in this example, a single linear layer might be enough to overfit.

Happy to incorporate your suggestions! Feel free to create a PR that makes changes to this branch.

@nathanielsimard (Member) commented

@wcshds @NicoZweifel I have identified the issue and we already have a planned fix; we will prioritize it now, since it directly affects a real-world use case. The problem lies in the autodiff graph, which is always attached to a tensor. When two tensors with different graphs interact, we merge the graphs. However, this process assumes that all nodes in the graph will eventually interact, which is not the case for the LSTM: you may only use the hidden_state, while the graph is actually held by the gate_state, which explains the problem with the current LSTM implementation.

We already plan to implement a client/server architecture in burn-autodiff to avoid graph merging and locking, which will also fix this problem.
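
Concretely, the pattern that triggers it looks roughly like this (a sketch only; the exact names and the order of the returned tuple are assumptions):

// Sketch of the failure mode described above, not of the fix.
// Assumed API shape: Lstm::forward returns both states as a tuple.
let (gate_state, hidden_state) = self.lstm.forward(x, None);

// Only hidden_state is passed on. The autodiff graph tracking the LSTM's
// internal computations stays attached to gate_state, which is dropped here,
// so the LSTM and every layer before it never receive gradients.
let _ = gate_state;
let output = self.output_layer.forward(hidden_state);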

@antimora (Collaborator) commented
> I have identified the issue and we already have a planned fix; we will prioritize it now, since it directly affects a real-world use case. [...]

@nathanielsimard do we have a separate ticket for the planned fix? It would be good to track it and link it here.

@NicoZweifel (Contributor, Author) commented
> I have identified the issue and we already have a planned fix; we will prioritize it now, since it directly affects a real-world use case. [...]

@nathanielsimard Kind of off topic, but it would be cool to have a generic TimeSeriesDataset that supports windowing, similar to what other libraries offer. If this is something that is desired, I could look into it in a separate issue/PR; a rough sketch of what I have in mind follows below.
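
Something along these lines (names are hypothetical, not an existing burn API; only the Dataset trait is assumed):

use burn::data::dataset::Dataset;

// Hypothetical sketch of a generic windowed dataset; WindowDataset is not an
// existing burn type. It yields fixed-size, overlapping windows over any Dataset.
pub struct WindowDataset<D, I> {
    dataset: D,
    window_size: usize,
    _marker: std::marker::PhantomData<I>,
}

impl<D: Dataset<I>, I> WindowDataset<D, I> {
    // Assumes window_size >= 1.
    pub fn new(dataset: D, window_size: usize) -> Self {
        Self { dataset, window_size, _marker: std::marker::PhantomData }
    }
}

impl<D: Dataset<I>, I: Send + Sync> Dataset<Vec<I>> for WindowDataset<D, I> {
    // A window starting at `index` covers the items index..index + window_size.
    fn get(&self, index: usize) -> Option<Vec<I>> {
        (index..index + self.window_size)
            .map(|i| self.dataset.get(i))
            .collect()
    }

    fn len(&self) -> usize {
        self.dataset.len().saturating_sub(self.window_size - 1)
    }
}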

@antimora (Collaborator) commented
> Kind of off topic, but it would be cool to have a generic TimeSeriesDataset that supports windowing, similar to what other libraries offer. [...]

@NicoZweifel, that would be a great addition. You can file an issue for this and we can assign it to you.

NicoZweifel mentioned this pull request Mar 28, 2024
@NicoZweifel (Contributor, Author) commented
> @NicoZweifel, that would be a great addition. You can file an issue for this and we can assign it to you.

Thanks, I created a separate issue to discuss the details 👍

NicoZweifel mentioned this pull request Apr 10, 2024
@github-actions bot commented

This PR has been marked as stale because it has not been updated for over a month.

github-actions bot added the stale (The issue or pr has been open for too long) label May 19, 2024
Labels: example (Related to examples), stale (The issue or pr has been open for too long)