C++

Huawei MindSpore: Developing a Dataset Loading Operator

Posted on 2021-05-14, 28 min read

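MindData is MindSpore's data processing system. It provides the data loading and preprocessing pipeline for MindSpore.

In training, MindData loads the training data from the file system into the training system. It then runs the data through the processing pipeline, which applies a series of transformations and data augmentations. Finally it assembles the results into Tensors, which are fed to the framework for the forward and backward passes.

In inference, MindData loads the inference data into memory and applies the predefined transformations. It then feeds the result to the framework as Tensors.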

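MindData lets you define data loading and processing pipelines through either its Python API or its C++ API. At run time MindData executes an Execution Tree. Each node of the tree is one concrete step of the pipeline, for example a Map operator that applies data augmentations after loading, or a Repeat operator. This article covers the most basic kind of node: the dataset loading operator.
A complete data processing operator has four parts: the Op implementation, the IR (Intermediate Representation) layer definition, the Python API definition, and the C++ API definition.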

Docker environment setup

docker pull mindspore/mindspore-cpu:devel
docker run -it -p 8023:22 -p 10022:10022 --name="mindspore" -v /home/docker_swap:/docker_swap mindspore/mindspore-cpu:devel /bin/bash

Inside the container, download the MindSpore source code and install ssh and libboost-dev:

apt update
apt upgrade
apt-get install libboost-dev
apt-get install openssh-server    # lets VS Code connect to the container for development

cd ~
git clone https://gitee.com/yangyueren/mindspore.git

Build

cd mindspore

bash build.sh -e cpu -j24 -t on    # this takes a long time

pip install ./mindspore/build/package/mindspore-1.2.0-cp37-cp37m-linux_x86_64.whl

Developing the dataset loading operator

This section walks through the development of a dataset loading operator, using the Places365 dataset as the example.
Full code: https://gitee.com/yangyueren/mindspore/tree/op_places365/

Create a branch

git checkout -b places365_dataset

The Places365 loading operator requires changes to, or new versions of, the following files:

# Low-level Op: the lowest-level class, which actually loads the dataset
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/places365_op.cc
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/places365_op.h


# Intermediate node representation (IR layer)
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/places365_node.cc
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/places365_node.h


# C++ API
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/samplers.h


# Python API, bound to the operator implemented in C++
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/datasetops/source/bindings.cc
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc

mindspore/dataset/engine/validators.py
mindspore/dataset/engine/datasets.py

Low-level Op development

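The example is again Places365, which covers 365 scene categories. The splits are:
- train-standard: 1.8 million images.
- train-challenge: 8 million images.
- val: 36,500 images.

Places365 provides each split in two versions, high resolution and 256*256 low resolution. The dataset layout is as follows: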

places365/                            // root directory of places365
|---categories_places365.txt          // two columns: category name, category ID
|---places365_train_standard.txt      // two columns: image path, category ID
|---places365_train_challenge.txt     // two columns: image path, category ID
|---places365_val.txt                 // two columns: image path, category ID
|---train_large_places365standard/    // high-resolution train-standard images
|---train_large_places365challenge/   // high-resolution train-challenge images
|---val_large/                        // high-resolution val images
|---train_256_places365standard/      // low-resolution train-standard images
|---train_256_places365challenge/     // low-resolution train-challenge images
|---val_256/                          // low-resolution val images

The operator reads the images and labels of the dataset and wraps them into Tensors. It returns them in two columns, named image and label.

Loading Places365 takes the following parameters:

@param std::string root - root directory of places365
@param const std::string &usage - usage of this dataset, can be 'train-standard', 'train-challenge' or 'val'; selects which file list and which image folder are read
@param bool small - if true, use the 256*256 low-resolution images; otherwise use the high-resolution images
@param bool decode - whether to decode the images

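usage determines which split (which file list) is loaded.
usage and small together determine which image folder is read, that is, which split at which resolution: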
const std::map<std::pair<std::string, bool>, std::string> K_IMAGES_META = {
        {std::pair<std::string, bool>("train-standard", false), "train_large_places365standard"},
        {std::pair<std::string, bool>("train-challenge", false), "train_large_places365challenge"},
        {std::pair<std::string, bool>("val", false), "val_large"},
        {std::pair<std::string, bool>("train-standard", true), "train_256_places365standard"},
        {std::pair<std::string, bool>("train-challenge", true), "train_256_places365challenge"},
        {std::pair<std::string, bool>("val", true), "val_256"},
};
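In the op, SanityCheck rejects an invalid usage before this table is consulted. Because of that, the `K_IMAGES_META.at(...)` call in LoadFileLists never throws in practice.

The following is a minimal standalone sketch, not part of the original op, of the same lookup done with `find()`. With `find()`, an unknown `(usage, small)` pair becomes an error value instead of a `std::out_of_range` exception. Two things are assumptions made for illustration:
- the helper name `ResolveImageFolder`;
- the plain `"/"` join, where the op itself uses `cv::utils::fs::join`.

```cpp
#include <map>
#include <string>
#include <utility>

// Mirrors K_IMAGES_META: (usage, small) -> image sub-folder under the root directory.
const std::map<std::pair<std::string, bool>, std::string> kImagesMeta = {
    {{"train-standard", false}, "train_large_places365standard"},
    {{"train-challenge", false}, "train_large_places365challenge"},
    {{"val", false}, "val_large"},
    {{"train-standard", true}, "train_256_places365standard"},
    {{"train-challenge", true}, "train_256_places365challenge"},
    {{"val", true}, "val_256"},
};

// Hypothetical helper: writes root/<sub-folder> into *folder and returns true,
// or returns false for an unknown usage instead of throwing std::out_of_range.
bool ResolveImageFolder(const std::string &root, const std::string &usage, bool small, std::string *folder) {
  auto it = kImagesMeta.find({usage, small});
  if (it == kImagesMeta.end()) {
    return false;
  }
  *folder = root + "/" + it->second;  // plain join; the op uses cv::utils::fs::join
  return true;
}
```

For example, `ResolveImageFolder("/data/places365", "val", true, &f)` sets `f` to `/data/places365/val_256`, while usage `"test"` returns false.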

Source code analysis:

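Places365 supports random access: given an index, the image and label at that index can be fetched directly. The user can therefore pass a custom sampler that decides which indices to load. For example, to use only the first 300 samples, there is no need to read all of Places365 into memory and then return the first 300.

For this reason Places365Op inherits from RandomAccessOp. The core of the work is LoadTensorRow, which LoadBuffer calls once for every key in a block. Its prototype is `Status Places365Op::LoadTensorRow(row_id_type row_id, TensorRow *trow)`. It puts the data at index row_id into trow.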
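Here is a self-contained sketch of the random-access idea. All names in it are hypothetical: `RandomAccessSource`, `FirstNSampler` and `LoadSampled` are not MindSpore APIs. The sketch works as follows:
- the metadata `(path, label)` sits in memory;
- a "first N" sampler produces the indices;
- only those rows are materialized.

"Loading" a row here just copies its path and label. It does not read and decode the image file.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// One output row: the image path stands in for the decoded image.
struct Row {
  std::string image_path;
  uint32_t label = 0;
};

// Hypothetical random-access source: metadata lives in memory, and a row is
// only built when its index is requested (the role LoadTensorRow plays).
class RandomAccessSource {
 public:
  explicit RandomAccessSource(std::vector<std::pair<std::string, uint32_t>> meta) : meta_(std::move(meta)) {}

  int64_t NumRows() const { return static_cast<int64_t>(meta_.size()); }

  // Builds exactly one row for row_id; returns false if the index is out of range.
  bool LoadRow(int64_t row_id, Row *row) const {
    if (row_id < 0 || row_id >= NumRows()) return false;
    const auto &entry = meta_[static_cast<size_t>(row_id)];
    *row = Row{entry.first, entry.second};
    ++loaded_;
    return true;
  }

  int64_t loaded() const { return loaded_; }

 private:
  std::vector<std::pair<std::string, uint32_t>> meta_;
  mutable int64_t loaded_ = 0;  // how many rows were actually materialized
};

// A "first N" sampler: yields indices 0..n-1, clamped to the dataset size.
std::vector<int64_t> FirstNSampler(int64_t num_rows, int64_t n) {
  std::vector<int64_t> ids;
  for (int64_t i = 0; i < n && i < num_rows; ++i) ids.push_back(i);
  return ids;
}

// Loads only the rows whose indices the sampler produced.
std::vector<Row> LoadSampled(const RandomAccessSource &src, const std::vector<int64_t> &ids) {
  std::vector<Row> rows;
  Row row;
  for (int64_t id : ids) {
    if (src.LoadRow(id, &row)) rows.push_back(row);
  }
  return rows;
}
```

With 1,000 metadata entries and `FirstNSampler(src.NumRows(), 300)`, only 300 rows are built. A sampler has the same effect on Places365Op.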

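Two other relevant functions are:
- `Status WalkAllFiles();` parses the image paths and labels without loading any image into memory.
- `Status LaunchThreadsAndInitOp();` handles initialization. It launches the worker threads, calls WalkAllFiles, and performs the handshake with the sampler.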
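WalkAllFiles essentially reads two-column text files.

Below is a minimal sketch of the file-list part. It assumes whitespace-separated `<relative path> <label>` lines, as described in the directory listing above. `ParseFileList` is a hypothetical stand-in for LoadFileLists, and it joins paths with `"/"` instead of `cv::utils::fs::join`.

```cpp
#include <cstdint>
#include <istream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for LoadFileLists: reads "<relative path> <label>" pairs
// and prefixes each path with the image folder. No image file is opened here.
std::vector<std::pair<std::string, uint32_t>> ParseFileList(std::istream &in, const std::string &folder) {
  std::vector<std::pair<std::string, uint32_t>> pairs;
  std::string path;
  uint32_t label = 0;
  while (in >> path >> label) {
    // Avoid a double slash when the listed path already starts with '/'.
    std::string sep = (!path.empty() && path[0] == '/') ? "" : "/";
    pairs.emplace_back(folder + sep + path, label);
  }
  return pairs;
}
```

Because the function takes any `std::istream`, it can be exercised with a `std::istringstream` and no real dataset. Against the real data you would pass a `std::ifstream` opened on the file list for the chosen usage.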

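The loading operator works as follows.

1. At first it loads no images into memory. All it needs is the total number of rows and the indices that belong to each class.
2. Places365Op has a nested Builder class that constructs a Places365Op. The static function CountTotalRows uses the Builder to obtain this dataset information. No images are read at this stage, so it is fast.
3. The information (the row count and the indices of each class) is handed to the sampler.
4. A Places365Op instance is built with a sampler that specifies which indices to load.
5. Multiple worker threads then load the data in parallel.

In short, the operator separates loading the metadata (fast) from loading the data itself (slow if everything is loaded). The user can choose which samples to load, and the operator fetches exactly those by index instead of loading the whole dataset every time.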
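The multi-threaded part splits the sampled indices into blocks and deals them out to the workers.

The following sketch uses the hypothetical name `DistributeKeys`. It mirrors the logic of TraversalSampleIds and operator() in places365_op.cc:
- keys are grouped into blocks of `rows_per_buffer`;
- block k goes to worker k % num_workers;
- out-of-range ids are skipped;
- a trailing partial block is flushed.

It uses plain vectors in place of MindSpore's IOBlock queues and omits the EOE/EOF control blocks.

```cpp
#include <cstdint>
#include <vector>

// One IOBlock is modeled as a list of row keys.
using Block = std::vector<int64_t>;

// Groups valid sample ids into blocks of rows_per_buffer keys and assigns the
// blocks to per-worker queues round-robin, as the master thread does.
std::vector<std::vector<Block>> DistributeKeys(const std::vector<int64_t> &sample_ids, int64_t num_rows,
                                               int32_t rows_per_buffer, int32_t num_workers) {
  if (num_workers <= 0 || rows_per_buffer <= 0) return {};
  std::vector<std::vector<Block>> queues(num_workers);
  Block keys;
  int64_t buf_cnt = 0;
  for (int64_t id : sample_ids) {
    if (id >= num_rows) continue;  // index out of bound, skipped like the original op
    keys.push_back(id);
    if (static_cast<int32_t>(keys.size()) == rows_per_buffer) {
      queues[buf_cnt++ % num_workers].push_back(keys);
      keys.clear();
    }
  }
  if (!keys.empty()) {  // flush the last, partially filled block
    queues[buf_cnt++ % num_workers].push_back(keys);
  }
  return queues;
}
```

Take ids 0 to 9 with `rows_per_buffer = 3` and 2 workers. Worker 0 receives {0,1,2} and {6,7,8}, and worker 1 receives {3,4,5} and {9}.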

places365_op.h

/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PLACES365_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PLACES365_OP_H_

#include <memory>
#include <string>
#include <algorithm>
#include <map>
#include <vector>
#include <utility>
#include <opencv2/opencv.hpp>
#include <opencv2/core/utils/filesystem.hpp>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {
// Forward declares
template <typename T>
class Queue;

using Places365LabelPair = std::pair<std::shared_ptr<Tensor>, uint32_t>;

class Places365Op : public ParallelOp, public RandomAccessOp {
 public:
  class Builder {
   public:
    // Constructor for Builder class of Places365Op
    Builder();

    // Destructor.
    ~Builder() = default;

    // Setter method
    // @param int32_t rows_per_buffer
    // @return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method
    // @param int32_t op_connector_size
    // @return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method
    // @param int32_t num_workers
    // @return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method
    // @param std::shared_ptr<Sampler> sampler
    // @return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
      builder_sampler_ = std::move(sampler);
      return *this;
    }

    // Setter method
    // @param const std::string &dir
    // @return
    Builder &SetDir(const std::string &dir) {
      builder_dir_ = dir;
      return *this;
    }

    // Setter method
    // @param const std::string &usage
    // @return
    Builder &SetUsage(const std::string &usage) {
      builder_usage_ = usage;
      return *this;
    }

    // Setter method
    // @param bool small
    // @return
    Builder &SetSmall(bool small) {
      builder_small_ = small;
      return *this;
    }
    // Setter method
    // @param bool decode
    // @return
    Builder &SetDecode(bool decode) {
      builder_decode_ = decode;
      return *this;
    }


    // Check validity of input args
    // @return Status The status code returned
    Status SanityCheck();

    // The builder "Build" method creates the final object.
    // @param std::shared_ptr<Places365Op> *op - DatasetOp
    // @return Status The status code returned
    Status Build(std::shared_ptr<Places365Op> *op);

   private:
    std::string builder_dir_;
    std::string builder_usage_;
    bool builder_small_;
    bool builder_decode_;
    int32_t builder_num_workers_;
    int32_t builder_rows_per_buffer_;
    int32_t builder_op_connector_size_;
    std::shared_ptr<SamplerRT> builder_sampler_;
    std::unique_ptr<DataSchema> builder_schema_;
  };

  // Constructor
  // @param std::string root - root directory of places365
  // @param const std::string &usage - Usage of this dataset, can be 'train-standard', 'train-challenge' or 'val'
  // @param bool small - Use 256*256 low-resolution images if true, high-resolution images otherwise
  // @param bool decode - Decode jpg format images
  // @param int32_t num_workers - number of workers reading images in parallel
  // @param int32_t rows_per_buffer - number of images (rows) in each buffer
  // @param int32_t queue_size - connector queue size
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the places365 dataset
  // @param std::shared_ptr<SamplerRT> sampler - sampler tells Places365Op what to read
  Places365Op(const std::string &root, const std::string &usage, bool small, bool decode, int32_t num_workers, int32_t rows_per_buffer, 
          int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);

  // Destructor.
  ~Places365Op() = default;

  // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
  // @param int32_t worker_id - id of each worker
  // @return Status The status code returned
  Status WorkerEntry(int32_t worker_id) override;

  // Main Loop of Places365Op
  // Master thread: Fill IOBlockQueue, then goes to sleep
  // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
  // @return Status The status code returned
  Status operator()() override;

  // Method derived from RandomAccess Op, enable Sampler to get all ids for each class
  // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key: label, value: all row ids of this class
  // @return Status The status code returned
  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;

  // A print method typically used for debugging
  // @param out
  // @param show_all
  void Print(std::ostream &out, bool show_all) const override;

  // Function to count the number of samples in the Places365 dataset
  // @param dir path to the Places365 directory
  // @param const std::string &usage - Usage of this dataset, can be 'train-standard', 'train-challenge' or 'val'
  // @param const bool small - Use 256*256 low-resolution images if true, high-resolution images otherwise
  // @param const bool decode - Decode jpg format images
  // @param count output arg that will hold the number of rows in the dataset
  // @return Status The status code returned
  static Status CountTotalRows(const std::string &dir, const std::string &usage, const bool small, const bool decode, int64_t *count);

  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return "Places365Op"; }

 private:
  // Initialize Sampler, calls sampler->Init() within
  // @return Status The status code returned
  Status InitSampler();

  // Load a tensor row according to its index
  // @param row_id_type row_id - id for this tensor row (index into image_path_label_pairs_)
  // @param TensorRow *row - image & label read into this tensor row
  // @return Status The status code returned
  Status LoadTensorRow(row_id_type row_id, TensorRow *row);

  // @param const std::vector<int64_t> &keys - keys in ioblock
  // @param std::unique_ptr<DataBuffer> db
  // @return Status The status code returned
  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);

  // Iterate through all members in sampleIds and fill them into IOBlock.
  // @param std::shared_ptr<Tensor> sample_ids -
  // @param std::vector<int64_t> *keys - keys in ioblock
  // @return Status The status code returned
  Status TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);

  // Load the meta information of categories.
  // @param const std::string &category_meta_name
  // @return Status The status code returned
  Status LoadCategories(const std::string &category_meta_name);

  // Load the meta information of the file list (image paths and labels).
  // @param const std::string &filelists_meta_name
  // @return Status The status code returned
  Status LoadFileLists(const std::string &filelists_meta_name);

  // Get one piece of places365 data
  // @param uint32_t index Index of the data
  // @param std::shared_ptr<Tensor> *image_tensor Store the result in image_tensor
  // @return Status The status code returned
  Status GetPlaces365DataTensor(uint32_t index, std::shared_ptr<Tensor> *image_tensor);

  // Read all files in the directory
  // @return Status The status code returned
  Status WalkAllFiles();

  // Called first by operator(): launches the workers, parses the metadata and initializes the sampler
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp();



  // reset Op
  // @return Status The status code returned
  Status Reset() override;

  // Private function for computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;

  int64_t buf_cnt_;
  int64_t row_cnt_;
  int32_t rows_per_buffer_;
  std::unique_ptr<DataSchema> data_schema_;

  const std::string root_; // directory of image folder
  const std::string usage_; // can only be "train-challenge", "train-standard" or "val"
  const bool small_;
  const bool decode_;

  std::map<std::string, int> categorie2id_;

  std::vector<std::pair<std::string, uint32_t>> image_path_label_pairs_;

  // std::vector<Places365LabelPair> image_label_pairs_;

};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_PLACES365_OP_H_
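The nested Builder follows the usual builder pattern: fluent setters return `*this`, and `Build()` runs `SanityCheck()` before constructing the object.

The stripped-down sketch below shows the pattern on its own. `Places365Config`, `ConfigBuilder` and the default of 4 workers are assumptions for illustration. The real Builder differs in two ways: it reads num_parallel_workers from ConfigManager, and it also checks that the directory exists.

```cpp
#include <memory>
#include <set>
#include <string>

// Hypothetical configuration object produced by the builder.
struct Places365Config {
  std::string dir;
  std::string usage = "train-standard";  // same defaults as Places365Op::Builder
  bool small = true;
  bool decode = true;
  int num_workers = 4;  // assumption; the real Builder reads it from ConfigManager
};

class ConfigBuilder {
 public:
  // Fluent setters: each returns *this so calls can be chained.
  ConfigBuilder &SetDir(const std::string &dir) {
    cfg_.dir = dir;
    return *this;
  }
  ConfigBuilder &SetUsage(const std::string &usage) {
    cfg_.usage = usage;
    return *this;
  }
  ConfigBuilder &SetSmall(bool small) {
    cfg_.small = small;
    return *this;
  }
  ConfigBuilder &SetDecode(bool decode) {
    cfg_.decode = decode;
    return *this;
  }
  ConfigBuilder &SetNumWorkers(int num_workers) {
    cfg_.num_workers = num_workers;
    return *this;
  }

  // Validates first; builds only when there are no errors.
  // Returns an empty string on success, otherwise the accumulated error messages.
  std::string Build(std::unique_ptr<Places365Config> *out) const {
    std::string err = SanityCheck();
    if (err.empty()) {
      *out = std::make_unique<Places365Config>(cfg_);
    }
    return err;
  }

 private:
  // Collects every problem instead of stopping at the first one, like the op's SanityCheck.
  std::string SanityCheck() const {
    static const std::set<std::string> kValid = {"train-standard", "train-challenge", "val"};
    std::string err;
    if (cfg_.dir.empty()) err += "dataset dir is not set.\n";
    if (cfg_.num_workers <= 0) err += "num_parallel_workers must be greater than 0.\n";
    if (kValid.count(cfg_.usage) == 0) {
      err += "usage must be 'train-standard', 'train-challenge' or 'val', but got " + cfg_.usage + ".\n";
    }
    return err;
  }

  Places365Config cfg_;
};
```

Collecting all error messages before returning lets one failed `Build()` report every invalid argument at once. This is the same choice the op's SanityCheck makes.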

places365_op.cc

/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/places365_op.h"

#include <iostream>
#include <fstream>
#include <iomanip>
#include <set>
#include "utils/ms_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif

namespace mindspore {
namespace dataset {
const std::string K_CATEGORIES_META = "categories_places365.txt";
const std::map<std::string, std::string> K_FILE_LIST_META = {
    {"train-standard", "places365_train_standard.txt"},
    {"train-challenge", "places365_train_challenge.txt"},
    {"val", "places365_val.txt"}
};
const std::map<std::pair<std::string, bool>, std::string> K_IMAGES_META = {
        {std::pair<std::string, bool>("train-standard", false), "train_large_places365standard"},
        {std::pair<std::string, bool>("train-challenge", false), "train_large_places365challenge"},
        {std::pair<std::string, bool>("val", false), "val_large"},
        {std::pair<std::string, bool>("train-standard", true), "train_256_places365standard"},
        {std::pair<std::string, bool>("train-challenge", true), "train_256_places365challenge"},
        {std::pair<std::string, bool>("val", true), "val_256"},
};

Places365Op::Builder::Builder()
    : builder_usage_("train-standard"), builder_small_(true), builder_decode_(true), builder_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_num_workers_ = cfg->num_parallel_workers();
  builder_rows_per_buffer_ = cfg->rows_per_buffer();
  builder_op_connector_size_ = cfg->op_connector_size();
}

Status Places365Op::Builder::Build(std::shared_ptr<Places365Op> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  if (builder_sampler_ == nullptr) {
    const int64_t num_samples = 0;
    const int64_t start_index = 0;
    builder_sampler_ = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
  }
  builder_schema_ = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(
    builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(builder_schema_->AddColumn(
    ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  *ptr = std::make_shared<Places365Op>(builder_dir_, builder_usage_, builder_small_, builder_decode_, builder_num_workers_, builder_rows_per_buffer_,
                                   builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_));
  return Status::OK();
}

Status Places365Op::Builder::SanityCheck() {
  const std::set<std::string> valid = {"train-standard", "train-challenge", "val"};
  Path dir(builder_dir_);
  std::string err_msg;
  err_msg += dir.IsDirectory() == false
               ? "Invalid parameter, Places365 path is invalid or not set, path: " + builder_dir_ + ".\n"
               : "";
  err_msg += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
                                           std::to_string(builder_num_workers_) + ".\n"
                                       : "";
  err_msg += valid.find(builder_usage_) == valid.end()
               ? "Invalid parameter, usage must be 'train-standard', 'train-challenge', 'val', but got " + builder_usage_ + ".\n"
               : "";
  return err_msg.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
}

Places365Op::Places365Op(const std::string &root, const std::string &usage, bool small, bool decode, int32_t num_workers,
                         int32_t rows_per_buffer, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
                         std::shared_ptr<SamplerRT> sampler)
    : ParallelOp(num_workers, queue_size, std::move(sampler)),
      // members are initialized in declaration order to avoid -Wreorder warnings
      buf_cnt_(0),
      row_cnt_(0),
      rows_per_buffer_(rows_per_buffer),
      data_schema_(std::move(data_schema)),
      root_(root),
      usage_(usage),
      small_(small),
      decode_(decode),
      categorie2id_({}),
      image_path_label_pairs_({}) {
  io_block_queues_.Init(num_workers, queue_size);
}

Status Places365Op::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
  for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
    if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
    keys->push_back(*itr);
    row_cnt_++;
    if (row_cnt_ % rows_per_buffer_ == 0) {
      RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
        std::make_unique<IOBlock>(IOBlock(*keys, IOBlock::kDeIoBlockNone))));
      keys->clear();
    }
  }
  return Status::OK();
}

// functor that contains the main logic of Places365 op
Status Places365Op::operator()() {
  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
  std::unique_ptr<DataBuffer> sampler_buffer;
  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
  while (true) {  // each iterator is 1 epoch
    std::vector<int64_t> keys;
    keys.reserve(rows_per_buffer_);
    while (sampler_buffer->eoe() == false) {
      std::shared_ptr<Tensor> sample_ids;
      RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0));
      if (sample_ids->type() != DataType(DataType::DE_INT64)) {
        RETURN_STATUS_UNEXPECTED("Invalid parameter, data type of Sampler Tensor isn't int64, got " +
                                 sample_ids->type().ToString());
      }
      RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys));
      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
    }
    if (keys.empty() == false) {
      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
    }
    if (IsLastIteration()) {
      RETURN_IF_NOT_OK(
        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
      RETURN_IF_NOT_OK(
        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
      for (int32_t i = 0; i < num_workers_; ++i) {
        RETURN_IF_NOT_OK(
          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
      }
      return Status::OK();
    } else {
      RETURN_IF_NOT_OK(
        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
    }

    if (epoch_sync_flag_) {
      // If epoch_sync_flag_ is set, then master thread sleeps until all the worker threads have finished their job for
      // the current epoch.
      RETURN_IF_NOT_OK(WaitForWorkers());
    }
    // If not the last repeat, self-reset and go to loop again.
    if (!IsLastIteration()) {
      RETURN_IF_NOT_OK(Reset());
      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
    }
    UpdateRepeatAndEpochCounter();
  }
}

// contains the logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_
Status Places365Op::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  int64_t buffer_id = worker_id;
  std::unique_ptr<IOBlock> iOBlock;
  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock));
  while (iOBlock != nullptr) {
    if (iOBlock->wait() == true) {
      // Sync io_block is a signal that master thread wants us to pause and sync with other workers.
      // The last guy who comes to this sync point should reset the counter and wake up the master thread.
      if (++num_workers_paused_ == num_workers_) {
        wait_for_workers_post_.Set();
      }
    } else if (iOBlock->eoe() == true) {
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
      buffer_id = worker_id;
    } else if (iOBlock->eof() == true) {
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(iOBlock->GetKeys(&keys));
      if (keys.empty() == true) return Status::OK();  // empty key is a quit signal for workers
      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
      RETURN_IF_NOT_OK(LoadBuffer(keys, &db));
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
      buffer_id += num_workers_;
    }
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock));
  }
  RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker.");
}

Status Places365Op::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  std::shared_ptr<Tensor> image, label;
  // read the image from disk (decoded if decode_ is set) and wrap its label in a scalar tensor
  RETURN_IF_NOT_OK(GetPlaces365DataTensor(row_id, &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(image_path_label_pairs_[row_id].second, &label));

  (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
  // trow->setPath({image_path_[row_id], label_path_[row_id]});
  return Status::OK();
}

// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer
Status Places365Op::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
  TensorRow trow;
  for (const int64_t &key : keys) {
    RETURN_IF_NOT_OK(this->LoadTensorRow(key, &trow));
    deq->push_back(std::move(trow));
  }
  (*db)->set_tensor_table(std::move(deq));
  return Status::OK();
}

void Places365Op::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nNumber of rows:" << num_rows_ << "\nPlaces365 Directory: " << root_ << "\n\n";
  }
}

// Reset Sampler and wakeup Master thread (functor)
Status Places365Op::Reset() {
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";
  RETURN_IF_NOT_OK(sampler_->ResetSampler());
  row_cnt_ = 0;
  return Status::OK();
}

// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
Status Places365Op::InitSampler() {
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
  return Status::OK();
}

// Derived from RandomAccessOp
Status Places365Op::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
  if (cls_ids == nullptr || !cls_ids->empty() || image_path_label_pairs_.empty()) {
    if (image_path_label_pairs_.empty()) {
      RETURN_STATUS_UNEXPECTED("No image found in dataset, please check if Op read images successfully or not.");
    } else {
      RETURN_STATUS_UNEXPECTED(
        "Map for storaging image-index pair is nullptr or has been set in other place,"
        "it must be empty before using GetClassIds.");
    }
  }
  for (size_t i = 0; i < image_path_label_pairs_.size(); ++i) {
    (*cls_ids)[image_path_label_pairs_[i].second].push_back(i);
  }
  for (auto &pair : (*cls_ids)) {
    pair.second.shrink_to_fit();
  }
  return Status::OK();
}

// Load the meta information of categories.
// @param const std::string &category_meta_name
// @return Status The status code returned
Status Places365Op::LoadCategories(const std::string &category_meta_name) {
  std::ifstream reader(category_meta_name);
  CHECK_FAIL_RETURN_UNEXPECTED(!reader.fail(), "Invalid file, failed to open: " + category_meta_name);
  std::string path;
  int label;
  // each line: <category name> <category ID>
  while (reader >> path >> label) {
    categorie2id_.insert({path, label});
  }
  reader.close();
  return Status::OK();
}

// Load the meta information of the file list (image paths and labels).
// @param const std::string &filelists_meta_name
// @return Status The status code returned
Status Places365Op::LoadFileLists(const std::string &filelists_meta_name) {
  std::ifstream reader(filelists_meta_name);
  CHECK_FAIL_RETURN_UNEXPECTED(!reader.fail(), "Invalid file, failed to open: " + filelists_meta_name);
  std::string path;
  uint32_t label;
  std::string folder_path = cv::utils::fs::join(root_, K_IMAGES_META.at(std::make_pair(usage_, small_)));
  image_path_label_pairs_.clear();
  // each line: <image path relative to folder_path> <category ID>
  while (reader >> path >> label) {
    image_path_label_pairs_.push_back({cv::utils::fs::join(folder_path, path), label});
  }
  reader.close();
  return Status::OK();
}

// Get one piece of places365 data
// @param uint32_t index Index of the data
// @param std::shared_ptr<Tensor> *image_tensor Store the result in image_tensor
// @return Status The status code returned
Status Places365Op::GetPlaces365DataTensor(uint32_t index, std::shared_ptr<Tensor> *image_tensor) {
  std::string file_path = image_path_label_pairs_[index].first;

  // read the raw (encoded) bytes of the image file into a tensor
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(file_path, image_tensor));
  if (decode_) {
    Status rc = Decode(*image_tensor, image_tensor);
    if (rc.IsError()) {
      *image_tensor = nullptr;
      std::string err_msg = "Invalid data, failed to decode image: " + file_path;
      return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
    }
  }
  return Status::OK();
}

// Read all files in the directory
// @return Status The status code returned
Status Places365Op::WalkAllFiles() {
  RETURN_IF_NOT_OK(LoadCategories(cv::utils::fs::join(root_, K_CATEGORIES_META)));
  RETURN_IF_NOT_OK(LoadFileLists(cv::utils::fs::join(root_, K_FILE_LIST_META.at(usage_))));
  num_rows_ = image_path_label_pairs_.size();
  if (num_rows_ == 0) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, no valid data matching the dataset API Places365Dataset. Please check file path or dataset API.");
  }
  return Status::OK();
}

// Called first by operator(): launches the workers, parses the metadata and initializes the sampler
// @return Status The status code returned
Status Places365Op::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&Places365Op::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(this->WalkAllFiles());
  RETURN_IF_NOT_OK(this->InitSampler());  // handshake with the sampler
  return Status::OK();
}


Status Places365Op::CountTotalRows(const std::string &dir, const std::string &usage, const bool small, const bool decode, int64_t *count) {
  // Build a temporary op only to parse the metadata; images are checked for existence but not read.
  std::shared_ptr<Places365Op> op;
  *count = 0;
  RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).SetSmall(small).SetDecode(decode).Build(&op));

  RETURN_IF_NOT_OK(op->WalkAllFiles());

  for (size_t i = 0; i < op->image_path_label_pairs_.size(); ++i) {
    CHECK_FAIL_RETURN_UNEXPECTED(cv::utils::fs::exists(op->image_path_label_pairs_[i].first),
                                 "Invalid data, image file does not exist: " + op->image_path_label_pairs_[i].first);
  }
  *count = op->image_path_label_pairs_.size();

  return Status::OK();
}

Status Places365Op::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->column(i).name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}


}  // namespace dataset
}  // namespace mindspore
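GetClassIds simply groups row indices by label.

Below is a standalone sketch of that computation. It uses the hypothetical name `GroupIdsByClass` and returns a bool in place of Status. A class-aware sampler, presumably MindSpore's PKSampler, would then pick indices per class from this map.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Sketch of what GetClassIds computes: label -> all row indices with that label.
// Returns false under the same conditions the op rejects: null output map,
// an already-filled map, or no rows at all.
bool GroupIdsByClass(const std::vector<std::pair<std::string, uint32_t>> &pairs,
                     std::map<int32_t, std::vector<int64_t>> *cls_ids) {
  if (cls_ids == nullptr || !cls_ids->empty() || pairs.empty()) return false;
  for (std::size_t i = 0; i < pairs.size(); ++i) {
    (*cls_ids)[static_cast<int32_t>(pairs[i].second)].push_back(static_cast<int64_t>(i));
  }
  for (auto &entry : *cls_ids) {
    entry.second.shrink_to_fit();  // release spare capacity, as the op does
  }
  return true;
}
```

For rows labeled {2, 0, 2}, the map becomes {0: [1], 2: [0, 2]}. The map is built from the in-memory metadata alone, so no image is touched.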

