<template>
    <div class="MonitoringVueControl">
        <!-- Runs trainModel(): builds and trains a small CNN on MNIST; its summary and charts are drawn into the output area. -->
        <b-button v-on:click="trainModel">훈련 테스트</b-button>
        <!-- Runs showTestView(): draws a few MNIST training samples and their labels into the output area. -->
        <b-button v-on:click="showTestView">데이터 보기</b-button>
        <!-- Output area, accessed from the script as this.$refs.testview. -->
        <div ref="testview"></div>
    </div>
</template>
<script>
    /**
     * Begin: Sean's modifications.
     */
     import * as ncai_data from '@/ncai-core/ncai_data.js' //mnist
    import * as ncai_tutor2 from '@/ncai-core/tutorials/ncai_tutor2.js' //mnist
    import { MnistData } from '@/ncai-core//datasets/mnist/mnistdata.js';
    import * as tf from '@tensorflow/tfjs';
    import * as tfvis from '@tensorflow/tfjs-vis';
    /**
     * End: Sean's modifications.
     */



    export default {
        name: 'MonitoringVueControl',
        data() {
            return {}
        },
        methods: {
            /**
             * Click handler for the "train" button.
             *
             * Loads the MNIST dataset, builds a fresh CNN (see getModel), renders
             * its summary into the `testview` element and trains it, streaming
             * loss/accuracy charts into the same element.
             *
             * The model is local to this call and nothing keeps a reference to it
             * afterwards, so its weights are released once training finishes or
             * fails. Without that, every click would leak a model's worth of
             * tensor memory.
             */
            async trainModel() {
                console.log("======trainModel=====");
                // NOTE(review): the dataset is constructed and loaded again on every
                // click; consider caching it if repeated loads are slow.
                const data = new MnistData();
                await data.load();

                const model = this.getModel();
                tfvis.show.modelSummary({ name: 'Model Architecture', tab: 'Model', drawArea: this.$refs.testview }, model);

                try {
                    await this.train(model, data);
                } finally {
                    model.dispose();
                }
            },

            /**
             * Click handler for the "view data" button.
             *
             * Loads the MNIST dataset and draws a few training samples into the
             * `testview` element: first as images, then their labels reduced with
             * argmax. `ncai_data.displayData` also supports a 'text' mode, which
             * was used here previously.
             */
            async showTestView() {
                console.log("======showTestView=====");
                const data = new MnistData();
                await data.load();

                const numData = 3;
                const [showing_datas_X, showing_datas_Y] = this.get_MnistData(data, numData);
                // NOTE(review): these tensors (and the reshaped copy) are never disposed.
                // They are left alive because it is not visible here whether
                // displayData reads them asynchronously. Confirm, then dispose.
                ncai_data.displayData(this.$refs.testview, showing_datas_X.reshape([numData, 28, 28, 1]), 'img');
                ncai_data.displayData(this.$refs.testview, showing_datas_Y, 'argmax');
            },

            /**
             * Draws `num` samples from the training split.
             *
             * @param {MnistData} data - a loaded dataset exposing nextTrainBatch().
             * @param {number} num - number of samples to draw.
             * @returns {[tf.Tensor, tf.Tensor]} [xs, labels]. The caller owns both
             *   tensors and is responsible for disposing them. xs holds num*28*28
             *   values (callers reshape it to [num, 28, 28, 1]); labels are
             *   presumably one-hot over 10 classes (TODO confirm against MnistData).
             */
            get_MnistData(data, num) {
                console.log("===getX_MnistData===");
                // tidy() frees intermediates but keeps the returned tensors alive.
                const [Xs, Ys] = tf.tidy(() => {
                    const d = data.nextTrainBatch(num);
                    return [d.xs, d.labels];
                });
                return [Xs, Ys];
            },

            /**
             * Builds and compiles a small CNN classifier for 28x28 single-channel
             * images with 10 output classes.
             *
             * Architecture: conv(5x5, 8) -> maxpool(2) -> conv(5x5, 16) -> maxpool(2)
             * -> flatten -> dense(10, softmax). Compiled with Adam, categorical
             * cross-entropy and an accuracy metric.
             *
             * @returns {tf.Sequential} the compiled model.
             */
            getModel() {
                const model = tf.sequential();

                const IMAGE_WIDTH = 28;
                const IMAGE_HEIGHT = 28;
                const IMAGE_CHANNELS = 1;

                // In the first layer of our convolutional neural network we have 
                // to specify the input shape. Then we specify some parameters for 
                // the convolution operation that takes place in this layer.
                model.add(tf.layers.conv2d({
                    // TF.js image tensors are laid out [height, width, channels].
                    inputShape: [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS],
                    kernelSize: 5,
                    filters: 8,
                    strides: 1,
                    activation: 'relu',
                    kernelInitializer: 'varianceScaling'
                }));

                // The MaxPooling layer acts as a sort of downsampling using max values
                // in a region instead of averaging.  
                model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

                // Repeat another conv2d + maxPooling stack. 
                // Note that we have more filters in the convolution.
                model.add(tf.layers.conv2d({
                    kernelSize: 5,
                    filters: 16,
                    strides: 1,
                    activation: 'relu',
                    kernelInitializer: 'varianceScaling'
                }));
                model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

                // Now we flatten the output from the 2D filters into a 1D vector to prepare
                // it for input into our last layer. This is common practice when feeding
                // higher dimensional data to a final classification output layer.
                model.add(tf.layers.flatten());

                // Our last layer is a dense layer which has 10 output units, one for each
                // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
                const NUM_OUTPUT_CLASSES = 10;
                model.add(tf.layers.dense({
                    units: NUM_OUTPUT_CLASSES,
                    kernelInitializer: 'varianceScaling',
                    activation: 'softmax'
                }));

                // Choose an optimizer, loss function and accuracy metric,
                // then compile and return the model
                const optimizer = tf.train.adam();
                model.compile({
                    optimizer: optimizer,
                    loss: 'categoricalCrossentropy',
                    metrics: ['accuracy'],
                });

                return model;
            },

            /**
             * Trains `model` on a small MNIST slice and plots progress into the
             * `testview` element.
             *
             * Uses 550 training samples and 100 validation samples, for 3 epochs
             * with batch size 64 and shuffling.
             *
             * The input tensors are released after fit() settles, whether it
             * succeeds or throws.
             *
             * @param {tf.LayersModel} model - a compiled model (see getModel).
             * @param {MnistData} data - a loaded dataset exposing nextTrainBatch()
             *   and nextTestBatch().
             * @returns {Promise<tf.History>} the History object returned by model.fit().
             */
            async train(model, data) {
                const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
                // Charts are drawn straight into the component's output element
                // rather than into the tfjs-vis visor.
                const fitCallbacks = tfvis.show.fitCallbacks(this.$refs.testview, metrics, { zoomToFitAccuracy: true });

                const BATCH_SIZE = 64;
                const TRAIN_DATA_SIZE = 550;
                const TEST_DATA_SIZE = 100;

                const [trainXs, trainYs] = tf.tidy(() => {
                    const d = data.nextTrainBatch(TRAIN_DATA_SIZE);
                    return [
                        d.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]),
                        d.labels
                    ];
                });

                const [testXs, testYs] = tf.tidy(() => {
                    const d = data.nextTestBatch(TEST_DATA_SIZE);
                    return [
                        d.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]),
                        d.labels
                    ];
                });

                try {
                    // `return await` is required so the finally block runs only
                    // after training has completed.
                    return await model.fit(trainXs, trainYs, {
                        batchSize: BATCH_SIZE,
                        validationData: [testXs, testYs],
                        epochs: 3,
                        shuffle: true,
                        callbacks: fitCallbacks
                    });
                } finally {
                    tf.dispose([trainXs, trainYs, testXs, testYs]);
                }
            }
        }
    }
</script>

<style scoped>
    /* Root container: lays out the buttons and the output area in a single column. */
    .MonitoringVueControl {
        color: #555555;
        display: flex;
        flex-direction: column;
        justify-content: space-around;
    }

    /* Compact labels for any form controls placed inside this component. */
    .MonitoringVueControl label {
        font-size: 10px !important;
        text-align: left;
        margin-bottom: 5px !important;
        margin-left: 5px !important;
    }

    /* Both selectors carry the root-class prefix, consistent with the other rules
       (previously `input` was listed bare). */
    .MonitoringVueControl select,
    .MonitoringVueControl input {
        border-radius: 5px !important;
        width: 180px !important;
    }

    /* Fixed-size accent buttons (applies to the <b-button> elements in the template). */
    .MonitoringVueControl button {
        margin-top: 5px;
        height: 25px;
        width: 180px;
        font-size: 13px;
        color: #ffffff;
        background: rgba(119, 132, 251, 0.7);
        box-shadow: 0px 8px 15px rgba(0, 0, 0, 0.1);
        text-align: center;
        line-height: 12px;
        border: none;
    }
</style>
