
[ECCV 2022] Classification-Regression for Chart Comprehension

Published in European Conference on Computer Vision 2022.

Classification-Regression for Chart Comprehension
Matan Levy, Rami Ben-Ari, Dani Lischinski

Abstract: Chart question answering (CQA) is a task used for assessing chart comprehension, which is fundamentally different from understanding natural images. CQA requires analyzing the relationships between the textual and the visual components of a chart, in order to answer general questions or infer numerical values. Most existing CQA datasets and models are based on simplifying assumptions that often enable surpassing human performance. In this work, we address this outcome and propose a new model that jointly learns classification and regression. Our language-vision setup uses co-attention transformers to capture the complex real-world interactions between the question and the textual elements. We validate our design with extensive experiments on the realistic PlotQA dataset, outperforming previous approaches by a large margin, while showing competitive performance on FigureQA. Our model is particularly well suited for realistic questions with out-of-vocabulary answers that require regression.


CRCT architecture, from the original paper.

Fig. 3: Our Classification-Regression Chart Transformer (CRCT) network architecture consists of two stages: detection and question answering. The detection stage (left) provides bounding boxes and object representations of the visual and textual elements (see Fig. 2). These features, along with the question text, enable the co-transformers in the second stage (right) to fuse both visual and textual information into a pooled tuple of two single feature vectors {h_v0, h_w0}. Next, our hybrid prediction head, containing two different MLPs, outputs a classification score and a regression result. co_i/self_i: co/self-attention at block i.
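
To make the hybrid head concrete, here is a minimal PyTorch sketch, assuming concatenation as the fusion step and illustrative dimensions; it is not the repository's actual module:

import torch
import torch.nn as nn

class HybridPredictionHead(nn.Module):
    """Sketch of a classification + regression head over the pooled
    co-transformer outputs {h_v0, h_w0} (dimensions are assumptions)."""
    def __init__(self, hidden_dim=768, num_classes=1024):
        super().__init__()
        # Classification MLP: scores over a fixed answer vocabulary.
        self.cls_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )
        # Regression MLP: a single scalar for numerical answers.
        self.reg_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, h_v0, h_w0):
        # Fuse the pooled visual and textual vectors; the paper's exact
        # fusion may differ from plain concatenation.
        fused = torch.cat([h_v0, h_w0], dim=-1)
        return self.cls_mlp(fused), self.reg_mlp(fused).squeeze(-1)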

Citation

If you find this work useful, please cite:

@inproceedings{levy2022classification,
  title={Classification-regression for chart comprehension},
  author={Levy, Matan and Ben-Ari, Rami and Lischinski, Dani},
  booktitle={European Conference on Computer Vision},
  pages={469--484},
  year={2022},
  organization={Springer}
}

Getting Started

Virtual Environment

Follow these steps:

  1. Set name: your_env_name in CRCT/environment.yml.
  2. Install the environment via conda: conda env create -f CRCT/environment.yml.
  3. Install Detectron2 (https://github.com/facebookresearch/detectron2).

Downloads

Dataset

The raw PlotQA dataset is available from the PlotQA repository: https://github.com/NiteshMethani/PlotQA.

Dataset Features

  • Post-Detection features: PlotQA image features, extracted from a pretrained Mask-RCNN.

  • Q&A files in .npy format: PlotQA Q&As converted to numpy .npy format.

    • qa_pairs_V1_train_10%.npy: a randomly selected 10% subset of the original V1 Question-Answers file.
    • qa_pairs_test.npy: the PlotQA test Question-Answers file.
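
The .npy files can be inspected directly with NumPy. A minimal sketch; the per-record layout is not documented here, so treat the printed fields as exploratory:

import numpy as np

# Object arrays need allow_pickle=True (an assumption about the file format).
qa_pairs = np.load("qa_pairs_test.npy", allow_pickle=True)
print(len(qa_pairs))  # number of question-answer records
print(qa_pairs[0])    # inspect the fields of a single record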

Model Weights


Config

The PlotQA config is in CRCT/config/plotqa.json. Set "main_folder": "My/full/path/to/CRCT/" to your own full path to the CRCT folder.
For more details, see CRCT/config/README.md.
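
A quick sanity check after editing the config; this is a sketch, and only the "main_folder" key is taken from this README:

import json
from pathlib import Path

with open("config/plotqa.json") as f:
    cfg = json.load(f)

# "main_folder" must point at your CRCT checkout (full path, as above).
assert Path(cfg["main_folder"]).is_dir(), "main_folder does not exist"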

Training

Detection Stage: training Mask-RCNN

To train your own Mask-RCNN on PlotQA images, use:

cd Detector
python frcnn.py --output MyDetector --batch-size 128 --num-gpus 4

For more details, see Detectron2.
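
If you prefer scripting the detector training against the Detectron2 API directly, a minimal sketch might look as follows; the dataset registration paths, class count, and base config are assumptions, and frcnn.py remains the repository's actual entry point:

import os
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer

# Hypothetical COCO-style annotations for the PlotQA images.
register_coco_instances("plotqa_train", {}, "annotations/train.json", "images/train")

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("plotqa_train",)
cfg.DATASETS.TEST = ()
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10  # number of chart-element classes (assumption)
cfg.SOLVER.IMS_PER_BATCH = 128        # matches --batch-size above
cfg.OUTPUT_DIR = "MyDetector"         # matches --output above

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()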

Question-Answering Stage: training CRCT

Use the following command to train your own model:

python train.py -qa_file qa_pairs_V1_10%.npy -dataset_config config/plotqa.json -batch_size 80 -save_name MyOwnCRCT -num_workers 2 -ddp -world_size 4 -num_proc 4 -L1

  • -qa_file: the Q&As file, in .npy format.
  • -dataset_config: path to the config file.
  • -ddp: a flag for training with PyTorch's Distributed Data Parallel.
  • -world_size: number of GPUs.
  • -num_proc: number of processes on each machine. For training on a single machine, set it equal to world_size.
  • -L1: train the regression head with L1 loss; otherwise, SmoothL1 is applied (see the sketch below).
  • -dist_url: URL of the file used for distributed initialization. Make sure this file does not exist before training!

For more details, check options.py.
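
The -L1 flag only switches the regression criterion. A short sketch of the difference, with illustrative values:

import torch
import torch.nn as nn

pred = torch.tensor([2.5, 0.0])
target = torch.tensor([3.0, 0.2])

# With -L1: plain mean absolute error.
l1 = nn.L1Loss()(pred, target)
# Without -L1: SmoothL1 (Huber-like), quadratic near zero, linear farther out.
smooth_l1 = nn.SmoothL1Loss()(pred, target)
print(l1.item(), smooth_l1.item())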

Evaluation

To evaluate a model, use the following command:

python evaluation.py -continue -qa_file qa_pairs_test.npy -num_workers 2 -ddp -world_size 4 -num_proc 4 -save_name MyEvalFolderCRCT -eval_set test -start_checkpoint crct.ckpt

  • -start_checkpoint: weights file in CRCT/checkpoints/.
  • -eval_set: one of test/val/train.
  • -continue: this flag raises an error if the checkpoint weights do not match the model (see the sketch below).
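
A sketch of the check that -continue implies, assuming standard PyTorch checkpointing (the file layout and key names are assumptions):

import torch

def load_checkpoint(model, path="checkpoints/crct.ckpt"):
    """Load weights; strict=True makes load_state_dict raise a RuntimeError
    on any key or shape mismatch, i.e. when the weights do not suit the model."""
    state = torch.load(path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" key.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    model.load_state_dict(state, strict=True)
    return model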

Interactive Demo

Try the CRCT model yourself: download the weights, features, and Q&A files from above, and run:

cd CRCT
python Interactive_demo.py

Examples

For inference examples, please visit the Project page at https://www.vision.huji.ac.il/crct/.

Acknowledgements

  • The backbone builds on Jiasen Lu's ViLBERT and Vishvak's implementation.
  • We wish to thank Nir Zabari and Or Kedar for their assistance in parts of this research.
  • Part of this research was conducted during an internship at IBM Research AI, Haifa.
