import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';


/**
 * Strips a fixed set of punctuation/special characters from `str`.
 *
 * The input is returned untouched (same reference) when it contains none of
 * those characters; otherwise a new string with every occurrence removed is
 * returned. Commas, whitespace, letters and digits are never removed.
 *
 * @param {string} str - Text to sanitise.
 * @returns {string} `str` without the special characters.
 */
function regExp(str) {
    // A fresh literal per call, so the stateful `lastIndex` of the global
    // regex never leaks between invocations.
    const specialChars = /[\{\}\[\]\/?.;:|\)*~`!^\-_+<>@\#$%&\\\=\(\'\"]/gi;

    // Only rewrite the input when at least one special character is present.
    return specialChars.test(str) ? str.replace(specialChars, "") : str;
}



/**
 * Maps every supported `layer_info.class_name` to the name of the `tf.layers`
 * factory that builds it. Matching is exact (case-sensitive), so the keys keep
 * the spelling this module has always accepted.
 */
const LAYER_FACTORY_BY_CLASS_NAME = new Map([
    // Basic
    ["Activation", "activation"],
    ["Dense", "dense"],
    ["Dropout", "dropout"],
    ["Embedding", "embedding"],
    ["Flatten", "flatten"],
    ["Permute", "permute"],
    ["RepeatVector", "repeatVector"],
    ["Reshape", "reshape"],
    ["SpatialDropout1d", "spatialDropout1d"],
    // Convolutional
    ["Conv1d", "conv1d"],
    ["Conv2d", "conv2d"],
    ["Conv2dTranspose", "conv2dTranspose"],
    ["Conv3d", "conv3d"],
    ["Cropping2D", "cropping2D"],
    ["DepthwiseConv2d", "depthwiseConv2d"],
    ["SeparableConv2d", "separableConv2d"],
    ["UpSampling2d", "upSampling2d"],
    // Merge
    ["Add", "add"],
    ["Average", "average"],
    ["Concatenate", "concatenate"],
    ["Dot", "dot"],
    ["Maximum", "maximum"],
    ["Minimum", "minimum"],
    ["Multiply", "multiply"],
    // Normalization
    ["batchNormalization", "batchNormalization"],
    ["layerNormalization", "layerNormalization"],
    // Pooling
    ["averagePooling1d", "averagePooling1d"],
    ["averagePooling2d", "averagePooling2d"],
    ["averagePooling3d", "averagePooling3d"],
    ["globalAveragePooling1d", "globalAveragePooling1d"],
    ["globalAveragePooling2d", "globalAveragePooling2d"],
    ["globalMaxPooling1d", "globalMaxPooling1d"],
    ["globalMaxPooling2d", "globalMaxPooling2d"],
    ["maxPooling1d", "maxPooling1d"],
    ["maxPooling2d", "maxPooling2d"],
    ["maxPooling3d", "maxPooling3d"],
]);

/**
 * Creates the symbolic input for an "Input"/"InputLayer" description.
 *
 * @param {Object} params - camelCased layer parameters.
 * @returns {*} Result of `tf.input(...)`, or `undefined` when neither
 *     `batchInputShape` nor `shape` is set.
 */
function makeInputFromParams(params) {
    if (params.batchInputShape) {
        // NOTE(review): tf.input documents `shape` as excluding the batch
        // dimension and `batchShape` as including it. If batchInputShape
        // carries a leading batch entry, `batchShape` may be the intended
        // option here - confirm against the configs this module receives.
        return tf.input({ shape: params.batchInputShape, name: params.name });
    }
    if (params.shape) {
        // `shape` may be a string such as "[28,28,1]" (brackets are stripped
        // by regExp, then the text is split on commas) or an actual array.
        const rawDims = Array.isArray(params.shape)
            ? params.shape
            : regExp(String(params.shape)).split(",");
        // Keep null/undefined dimensions as null instead of coercing them to 0.
        const shape = rawDims.map((dim) => (dim == null ? null : Number(dim)));
        return tf.input({ shape: shape, name: params.name });
    }
    return undefined;
}

/**
 * Builds a TensorFlow.js layer (or a symbolic input) from a layer description.
 *
 * The description's `config` is converted to camelCased parameters via
 * `makeParamsFromConfig`, `units` is coerced to a number, and
 * `batchInputShape` is dropped for every class except "InputLayer" (so an
 * "Input" description is always built from `shape`).
 *
 * @param {{name: string, class_name: string, config: Object}} layer_info
 *     Layer description.
 * @returns {*} The value returned by `tf.input` for "Input"/"InputLayer", the
 *     value returned by the matching `tf.layers.*` factory for other known
 *     classes, or `undefined` when the class name is not supported or an
 *     input description carries no shape.
 */
export function makeLayerFromLayerInfo(layer_info) {
    console.log("===============" + layer_info.name);

    const className = layer_info.class_name;
    const params = makeParamsFromConfig(layer_info.config);

    // Configs may carry `units` as a numeric string.
    if (params.units) params.units = Number(params.units);
    if (className !== "InputLayer") delete params.batchInputShape;

    // Inputs must be turned into a symbolic tensor rather than a layer.
    if (className === "Input" || className === "InputLayer") {
        return makeInputFromParams(params);
    }

    const factoryName = LAYER_FACTORY_BY_CLASS_NAME.get(className);
    if (factoryName === undefined) {
        // Still returns undefined as before, but no longer silently.
        console.warn(`makeLayerFromLayerInfo: unsupported layer class "${className}"`);
        return undefined;
    }
    return tf.layers[factoryName](params);
}

/**
 * Converts a layer config object into a parameter object.
 *
 * Every enumerable key is renamed with `changeConfigKeyToParamsKey`; entries
 * whose value is `null` or `undefined` are left out, and the
 * `biasInitializer` / `kernelInitializer` entries are never included.
 *
 * @param {Object} layer_config - Layer configuration keyed by config names.
 * @returns {Object} A new object keyed by the converted names.
 */
function makeParamsFromConfig(layer_config) {
    const excludedKeys = ["biasInitializer", "kernelInitializer"];
    const result = {};

    // for...in is kept on purpose: it tolerates a null/undefined config
    // (yielding an empty result) and walks the same keys as before.
    for (const configKey in layer_config) {
        const value = layer_config[configKey];
        if (value == null) continue;

        const paramKey = changeConfigKeyToParamsKey(configKey);
        if (excludedKeys.includes(paramKey)) continue;

        result[paramKey] = value;
    }

    return result;
}

/**
 * Converts an underscore-separated config key into a camelCase parameter key.
 *
 * The text before the first underscore is kept as-is; each following segment
 * gets its first character upper-cased and the underscores are dropped.
 *
 * @param {string} config_key_name - Key as it appears in the layer config.
 * @returns {string} The camelCased key.
 */
function changeConfigKeyToParamsKey(config_key_name) {
    const [firstSegment, ...otherSegments] = config_key_name.split("_");
    return otherSegments.reduce(
        (joined, segment) => joined + segment.charAt(0).toUpperCase() + segment.slice(1),
        firstSegment
    );
}
