# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
# From:
# https://github.com/Azure/MachineLearningNotebooks/blob/master/how-to-use-azureml/ml-frameworks/scikit-learn/train-hyperparameter-tune-deploy-with-sklearn/train_iris.py
import argparse
from pathlib import Path
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
def main() -> None:
    """Train an SVM classifier on the CSV dataset and print test-set metrics.

    Command-line arguments:
        --kernel: kernel type forwarded to ``SVC`` (default ``'linear'``).
        --penalty: value forwarded to ``SVC`` as ``C`` (default ``1.0``).

    Features are read from ``dataset/X.csv`` and labels (as strings) from
    ``dataset/y.csv``; the ``dataset`` folder is a relative path, so it is
    resolved against the current working directory. The data is split into
    train and test parts with a fixed ``random_state``, the classifier is
    fitted on the training part, and the accuracy and confusion matrix for
    the test part are printed to stdout.
    """
    # Kept function-local so that importing this module does not pull in
    # sklearn.svm until training is actually requested.
    from sklearn.svm import SVC

    parser = argparse.ArgumentParser()
    parser.add_argument('--kernel', type=str, default='linear', help='Kernel type to be used in the algorithm')
    parser.add_argument('--penalty', type=float, default=1.0, help='Penalty parameter of the error term')
    args = parser.parse_args()
    print(f'Kernel type: {args.kernel}')
    print(f'Penalty: {args.penalty}')

    # X -> features, y -> labels. The first row of each file is skipped
    # (presumably a header row -- confirm against the dataset).
    input_folder = Path("dataset")
    X = np.loadtxt(fname=input_folder / "X.csv", delimiter=',', skiprows=1)
    y = np.loadtxt(fname=input_folder / "y.csv", dtype='str', delimiter=',', skiprows=1)

    # Fixed random_state keeps the train/test split reproducible across runs.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # The kernel is taken from the command line, so the model is not
    # necessarily linear despite 'linear' being the default.
    classifier = SVC(kernel=args.kernel, C=args.penalty).fit(X_train, y_train)
    predictions = classifier.predict(X_test)

    # Mean accuracy on the held-out test data.
    accuracy = classifier.score(X_test, y_test)
    print(f'Accuracy of SVM classifier on test set: {accuracy:.2f}')

    # Confusion matrix of true test labels against predicted labels.
    cm = confusion_matrix(y_test, predictions)
    print(cm)
# Run the training routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()