import numpy as np

# 如果使用labelme打标
import skimage
from PIL import Image
import yaml
import os
import cv2 as cv
def extract_bboxes(mask):  # [[num_instances, (y1, x1, y2, x2)]]
    """Return the tight bounding box of every instance channel in ``mask``.

    Args:
        mask: array of shape [height, width, num_instances]. A pixel belongs
            to an instance when its value is non-zero (normally 0/1).

    Returns:
        int32 array of shape [num_instances, 4] whose rows are
        (y1, x1, y2, x2). y2 and x2 are exclusive, i.e. one past the last
        occupied row/column. An instance with no non-zero pixels (e.g. lost
        to resizing or cropping) yields an all-zero row.
    """
    num_instances = mask.shape[-1]
    result = np.zeros((num_instances, 4), dtype=np.int32)
    for k in range(num_instances):
        channel = mask[:, :, k]
        cols = np.flatnonzero(channel.any(axis=0))
        if cols.size == 0:
            # Empty instance: leave its row as zeros.
            continue
        rows = np.flatnonzero(channel.any(axis=1))
        # +1 makes the far edges exclusive.
        result[k] = (rows[0], cols[0], rows[-1] + 1, cols[-1] + 1)
    return result.astype(np.int32)
def img_box_from_labelme(img_file_path, classes):
    """Load a labelme-annotated sample and compute per-instance boxes.

    Args:
        img_file_path: directory containing ``img.png`` (the original image),
            ``label.png`` (label image where pixel value k > 0 marks instance
            k and 0 is background) and ``info.yaml`` (whose ``label_names``
            list starts with the background entry, e.g.
            ``[_background_, NG]``).
        classes: list of class names. A label name equal to ``classes[j]``
            gets class id ``j + 1``; 0 is reserved for background. If a name
            occurs more than once in ``classes``, its last occurrence wins.

    Returns:
        (image, boxes) where ``image`` is ``img.png`` as returned by
        ``cv2.imread`` (default color mode), and ``boxes`` is a float array of
        shape [num_obj, 5] with rows (y1, x1, y2, x2, class_id). y2/x2 are
        exclusive. ``num_obj`` is the largest pixel value in ``label.png``.
        Row i corresponds to the (i+1)-th entry of ``label_names``. Rows whose
        label name is not in ``classes`` keep class_id 0.

    Raises:
        FileNotFoundError: if ``img.png`` cannot be read.
        ValueError: if ``label.png`` is not a 2-D image with the same
            height/width as ``img.png``.
    """
    yaml_path = os.path.join(img_file_path, 'info.yaml')  # label_names: - _background_  - NG
    mask_path = os.path.join(img_file_path, 'label.png')
    img_path = os.path.join(img_file_path, 'img.png')

    # cv.imread signals failure by returning None rather than raising.
    image = cv.imread(img_path)
    if image is None:
        raise FileNotFoundError(f"cannot read image: {img_path}")
    h, w = image.shape[:2]

    # Copy the label pixels into an array and close the file immediately.
    with Image.open(mask_path) as label_img:
        label = np.array(label_img)
    if label.shape != (h, w):
        raise ValueError(
            f"label.png shape {label.shape} does not match img.png size {(h, w)}"
        )

    # int() prevents uint8 wrap-around (e.g. 255 + 1 == 0) in the arange below.
    num_obj = int(label.max())

    # One binary channel per instance id 1..num_obj, computed in a single
    # vectorized comparison instead of a per-pixel loop.
    ids = np.arange(1, num_obj + 1)
    mask = (label[:, :, np.newaxis] == ids).astype(np.uint8)

    boxes = np.zeros((num_obj, 5))
    boxes[:, :4] = extract_bboxes(mask)

    # safe_load: info.yaml only needs plain YAML types; avoid constructing
    # arbitrary Python objects from an on-disk file.
    with open(yaml_path, encoding='utf-8') as f:
        info = yaml.safe_load(f)
    labels = info['label_names'][1:]  # drop the leading background entry

    # Name -> 1-based class id; a later duplicate overwrites an earlier one.
    class_ids = {name: j + 1 for j, name in enumerate(classes)}
    # labels[i] names instance i + 1, i.e. row i of boxes. zip stops at the
    # shorter sequence, so label names without an instance row are ignored
    # instead of indexing past the end of boxes.
    for row, name in zip(boxes, labels):
        row[-1] = class_ids.get(name, 0)

    return image, boxes