
3. Training models

Jim Schwoebel edited this page Aug 12, 2020 · 9 revisions

Use this folder to train machine learning models with the model.py script. The script uses whichever model trainer is set by default_training_script in Allie's settings.
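In the example sessions on this page, default_training_script is ["tpot"]. The sketch below reads or changes that setting. It assumes settings.json is a flat JSON object, as the "settings" block in the report further down suggests. The file location and the helper names are assumptions and are not part of Allie.

```python
import json
from pathlib import Path


def read_training_scripts(settings_path):
    """Return the list stored under default_training_script in settings.json."""
    settings = json.loads(Path(settings_path).read_text())
    return settings.get("default_training_script", [])


def set_training_scripts(settings_path, scripts):
    """Overwrite default_training_script, keeping every other setting as-is."""
    path = Path(settings_path)
    settings = json.loads(path.read_text())
    settings["default_training_script"] = list(scripts)
    path.write_text(json.dumps(settings, indent=2))
```

For example, `read_training_scripts("settings.json")` would return `["tpot"]` for the settings shown in the regression report. Only pass trainer names that your Allie version actually supports.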

Getting started

To get started, go to this repository, run upgrade.py, and then run model.py from the training folder:

cd allie/training
python3 model.py 

You will then be asked a few questions about the training process: the data type, the number of classes, and the name of the model. In the transcripts below, the typed responses are the short lines that follow each question (for example, r, gender_all.csv, or gender).
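If you retrain often, you could pipe your answers into model.py instead of typing them. This sketch makes two assumptions drawn only from the transcripts on this page, not from a documented interface. First, that model.py reads each answer from standard input. Second, that the questions come in the order shown below. Both helper names are hypothetical.

```python
import subprocess


def build_answers(problem_type, source, targets, common_name):
    """Join answers in the prompt order seen in the transcripts:
    problem type (c/r), then the spreadsheet name (regression) or data-type
    number (classification), number of classes, each class, common name."""
    lines = [problem_type, source, str(len(targets)), *targets, common_name]
    return "\n".join(lines) + "\n"


def run_model_py(answers, cwd="."):
    """Run model.py with the prepared answers on stdin (hypothetical helper)."""
    return subprocess.run(
        ["python3", "model.py"], input=answers, text=True, cwd=cwd, check=True
    )
```

The two sessions on this page would correspond to `build_answers("r", "gender_all.csv", ["class_"], "gender")` and `build_answers("c", "1", ["males", "females"], "gender")`. If model.py asks a different question, the answers will go to the wrong prompts. Check the order against your own interactive run first.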

regression model example

For regression model training, you need a .CSV file in the ./train_dir folder. It should contain the annotations you want to predict, together with the data they describe. It is best to produce this file with the annotate.py script, but you can also use datasets created within Allie. You then choose the target column(s) from the spreadsheet, and the models are trained with the model trainer(s) specified in your settings.
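Before running model.py, you may want to check which columns the spreadsheet offers as targets. You may also want to confirm that your chosen target is numeric, which regression needs. This standard-library sketch reads the header row and validates one target column (class_ in the transcript below). The function name and the numeric check are illustrative and are not part of Allie.

```python
import csv


def check_regression_csv(csv_path, target):
    """Return (columns, problems) for a spreadsheet meant for regression.

    columns  -- header names, i.e. the candidate targets
    problems -- human-readable issues found in the target column
    """
    problems = []
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        columns = list(reader.fieldnames or [])
        if target not in columns:
            return columns, [f"target column {target!r} not found"]
        for row_number, row in enumerate(reader, start=2):  # row 1 is the header
            try:
                float(row[target])
            except (TypeError, ValueError):  # empty/missing cells or text
                problems.append(f"row {row_number}: {row[target]!r} is not numeric")
    return columns, problems
```

An empty `problems` list means every value in the target column parsed as a number.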

Then follow the same modeling steps, specifying a regression problem, the .CSV file of interest, and the target column:

python3 model.py
                                                                             
                                                                             
               AAA               lllllll lllllll   iiii                      
              A:::A              l:::::l l:::::l  i::::i                     
             A:::::A             l:::::l l:::::l   iiii                      
            A:::::::A            l:::::l l:::::l                             
           A:::::::::A            l::::l  l::::l iiiiiii     eeeeeeeeeeee    
          A:::::A:::::A           l::::l  l::::l i:::::i   ee::::::::::::ee  
         A:::::A A:::::A          l::::l  l::::l  i::::i  e::::::eeeee:::::ee
        A:::::A   A:::::A         l::::l  l::::l  i::::i e::::::e     e:::::e
       A:::::A     A:::::A        l::::l  l::::l  i::::i e:::::::eeeee::::::e
      A:::::AAAAAAAAA:::::A       l::::l  l::::l  i::::i e:::::::::::::::::e 
     A:::::::::::::::::::::A      l::::l  l::::l  i::::i e::::::eeeeeeeeeee  
    A:::::AAAAAAAAAAAAA:::::A     l::::l  l::::l  i::::i e:::::::e           
   A:::::A             A:::::A   l::::::ll::::::li::::::ie::::::::e          
  A:::::A               A:::::A  l::::::ll::::::li::::::i e::::::::eeeeeeee  
 A:::::A                 A:::::A l::::::ll::::::li::::::i  ee:::::::::::::e  
AAAAAAA                   AAAAAAAlllllllllllllllliiiiiiii    eeeeeeeeeeeeee  
                                                                             
                                                                             
                                                                             
                                                                             
                                                                             
                                                                             
                                                                             

is this a classification (c) or regression (r) problem? 
r
what is the name of the spreadsheet (in ./train_dir) used for prediction? 

 available: ['gender_all.csv']

gender_all.csv
how many classes would you like to model? (188 available) 
1
these are the available classes: ['onset_length', 'onset_detect_mean', 'onset_detect_std', 'onset_detect_maxv', 'onset_detect_minv', 'onset_detect_median', 'tempo', 'onset_strength_mean', 'onset_strength_std', 'onset_strength_maxv', 'onset_strength_minv', 'onset_strength_median', 'rhythm_0_mean', 'rhythm_0_std', 'rhythm_0_maxv', 'rhythm_0_minv', 'rhythm_0_median', 'rhythm_1_mean', 'rhythm_1_std', 'rhythm_1_maxv', 'rhythm_1_minv', 'rhythm_1_median', 'rhythm_2_mean', 'rhythm_2_std', 'rhythm_2_maxv', 'rhythm_2_minv', 'rhythm_2_median', 'rhythm_3_mean', 'rhythm_3_std', 'rhythm_3_maxv', 'rhythm_3_minv', 'rhythm_3_median', 'rhythm_4_mean', 'rhythm_4_std', 'rhythm_4_maxv', 'rhythm_4_minv', 'rhythm_4_median', 'rhythm_5_mean', 'rhythm_5_std', 'rhythm_5_maxv', 'rhythm_5_minv', 'rhythm_5_median', 'rhythm_6_mean', 'rhythm_6_std', 'rhythm_6_maxv', 'rhythm_6_minv', 'rhythm_6_median', 'rhythm_7_mean', 'rhythm_7_std', 'rhythm_7_maxv', 'rhythm_7_minv', 'rhythm_7_median', 'rhythm_8_mean', 'rhythm_8_std', 'rhythm_8_maxv', 'rhythm_8_minv', 'rhythm_8_median', 'rhythm_9_mean', 'rhythm_9_std', 'rhythm_9_maxv', 'rhythm_9_minv', 'rhythm_9_median', 'rhythm_10_mean', 'rhythm_10_std', 'rhythm_10_maxv', 'rhythm_10_minv', 'rhythm_10_median', 'rhythm_11_mean', 'rhythm_11_std', 'rhythm_11_maxv', 'rhythm_11_minv', 'rhythm_11_median', 'rhythm_12_mean', 'rhythm_12_std', 'rhythm_12_maxv', 'rhythm_12_minv', 'rhythm_12_median', 'mfcc_0_mean', 'mfcc_0_std', 'mfcc_0_maxv', 'mfcc_0_minv', 'mfcc_0_median', 'mfcc_1_mean', 'mfcc_1_std', 'mfcc_1_maxv', 'mfcc_1_minv', 'mfcc_1_median', 'mfcc_2_mean', 'mfcc_2_std', 'mfcc_2_maxv', 'mfcc_2_minv', 'mfcc_2_median', 'mfcc_3_mean', 'mfcc_3_std', 'mfcc_3_maxv', 'mfcc_3_minv', 'mfcc_3_median', 'mfcc_4_mean', 'mfcc_4_std', 'mfcc_4_maxv', 'mfcc_4_minv', 'mfcc_4_median', 'mfcc_5_mean', 'mfcc_5_std', 'mfcc_5_maxv', 'mfcc_5_minv', 'mfcc_5_median', 'mfcc_6_mean', 'mfcc_6_std', 'mfcc_6_maxv', 'mfcc_6_minv', 'mfcc_6_median', 'mfcc_7_mean', 'mfcc_7_std', 'mfcc_7_maxv', 
'mfcc_7_minv', 'mfcc_7_median', 'mfcc_8_mean', 'mfcc_8_std', 'mfcc_8_maxv', 'mfcc_8_minv', 'mfcc_8_median', 'mfcc_9_mean', 'mfcc_9_std', 'mfcc_9_maxv', 'mfcc_9_minv', 'mfcc_9_median', 'mfcc_10_mean', 'mfcc_10_std', 'mfcc_10_maxv', 'mfcc_10_minv', 'mfcc_10_median', 'mfcc_11_mean', 'mfcc_11_std', 'mfcc_11_maxv', 'mfcc_11_minv', 'mfcc_11_median', 'mfcc_12_mean', 'mfcc_12_std', 'mfcc_12_maxv', 'mfcc_12_minv', 'mfcc_12_median', 'poly_0_mean', 'poly_0_std', 'poly_0_maxv', 'poly_0_minv', 'poly_0_median', 'poly_1_mean', 'poly_1_std', 'poly_1_maxv', 'poly_1_minv', 'poly_1_median', 'spectral_centroid_mean', 'spectral_centroid_std', 'spectral_centroid_maxv', 'spectral_centroid_minv', 'spectral_centroid_median', 'spectral_bandwidth_mean', 'spectral_bandwidth_std', 'spectral_bandwidth_maxv', 'spectral_bandwidth_minv', 'spectral_bandwidth_median', 'spectral_contrast_mean', 'spectral_contrast_std', 'spectral_contrast_maxv', 'spectral_contrast_minv', 'spectral_contrast_median', 'spectral_flatness_mean', 'spectral_flatness_std', 'spectral_flatness_maxv', 'spectral_flatness_minv', 'spectral_flatness_median', 'spectral_rolloff_mean', 'spectral_rolloff_std', 'spectral_rolloff_maxv', 'spectral_rolloff_minv', 'spectral_rolloff_median', 'zero_crossings_mean', 'zero_crossings_std', 'zero_crossings_maxv', 'zero_crossings_minv', 'zero_crossings_median', 'RMSE_mean', 'RMSE_std', 'RMSE_maxv', 'RMSE_minv', 'RMSE_median', 'class_']
what is class #1 
class_
what is the 1-word common name for the problem you are working on? (e.g. gender for male/female classification) 
gender

model.py will then save the regression model to the proper folder, using TPOT as the model trainer. You will also get some awesome stats on the regression modeling session in a .JSON file like the one below. These include the mean absolute error, mean squared error, median absolute error, and R² score:

{"sample type": "csv", "created date": "2020-08-03 15:29:43.786976", "device info": {"time": "2020-08-03 15:29", "timezone": ["EST", "EDT"], "operating system": "Darwin", "os release": "19.5.0", "os version": "Darwin Kernel Version 19.5.0: Tue May 26 20:41:44 PDT 2020; root:xnu-6153.121.2~2/RELEASE_X86_64", "cpu data": {"memory": [8589934592, 2577022976, 70.0, 4525428736, 107941888, 2460807168, 2122092544, 2064621568], "cpu percent": 59.1, "cpu times": [22612.18, 0.0, 12992.38, 102624.04], "cpu count": 4, "cpu stats": [110955, 504058, 130337047, 518089], "cpu swap": [2147483648, 1096548352, 1050935296, 51.1, 44743286784, 329093120], "partitions": [["/dev/disk1s6", "/", "apfs", "ro,local,rootfs,dovolfs,journaled,multilabel"], ["/dev/disk1s5", "/System/Volumes/Data", "apfs", "rw,local,dovolfs,dontbrowse,journaled,multilabel"], ["/dev/disk1s4", "/private/var/vm", "apfs", "rw,local,dovolfs,dontbrowse,journaled,multilabel"], ["/dev/disk1s1", "/Volumes/Macintosh HD - Data", "apfs", "rw,local,dovolfs,journaled,multilabel"]], "disk usage": [499963174912, 10985529344, 317145075712, 3.3], "disk io counters": [1689675, 1773144, 52597518336, 34808844288, 1180797, 1136731], "battery": [100, -2, true], "boot time": 1596411904.0}, "space left": 317.145075712}, "session id": "fc54dd66-d5bc-11ea-9c75-acde48001122", "classes": ["class_"], "problem type": "regression", "model name": "gender_tpot_regression.pickle", "model type": "tpot", "metrics": {"mean_absolute_error": 0.37026379788606023, "mean_squared_error": 0.16954440031335424, "median_absolute_error": 0.410668441980656, "r2_score": 0.3199385720764347}, "settings": {"version": "1.0.0", "augment_data": false, "balance_data": true, "clean_data": false, "create_csv": true, "default_audio_augmenters": ["augment_tsaug"], "default_audio_cleaners": ["clean_mono16hz"], "default_audio_features": ["librosa_features"], "default_audio_transcriber": ["deepspeech_dict"], "default_csv_augmenters": ["augment_ctgan_regression"], 
"default_csv_cleaners": ["clean_csv"], "default_csv_features": ["csv_features"], "default_csv_transcriber": ["raw text"], "default_dimensionality_reducer": ["pca"], "default_feature_selector": ["rfe"], "default_image_augmenters": ["augment_imgaug"], "default_image_cleaners": ["clean_greyscale"], "default_image_features": ["image_features"], "default_image_transcriber": ["tesseract"], "default_outlier_detector": ["isolationforest"], "default_scaler": ["standard_scaler"], "default_text_augmenters": ["augment_textacy"], "default_text_cleaners": ["remove_duplicates"], "default_text_features": ["nltk_features"], "default_text_transcriber": ["raw text"], "default_training_script": ["tpot"], "default_video_augmenters": ["augment_vidaug"], "default_video_cleaners": ["remove_duplicates"], "default_video_features": ["video_features"], "default_video_transcriber": ["tesseract (averaged over frames)"], "dimension_number": 2, "feature_number": 20, "model_compress": false, "reduce_dimensions": false, "remove_outliers": true, "scale_features": false, "select_features": false, "test_size": 0.1, "transcribe_audio": false, "transcribe_csv": true, "transcribe_image": true, "transcribe_text": true, "transcribe_video": true, "transcribe_videos": true, "visualize_data": false, "default_dimensionionality_reducer": ["pca"]}, "transformer name": "", "training data": [], "sample X_test": [30.0, 116.1, 68.390715744171, 224.0, 3.0, 115.5, 129.19921875, 1.579895074162117, 1.4053805862299766, 6.915237601339313, 0.0, 1.1654598038099069, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8033179369485901, 0.00438967342343324, 0.8140795129312649, 0.7979309783326958, 0.80255447893579, 0.5772101965904585, 0.025367026843705915, 0.6147904436358145, 0.5452462503889344, 0.5720709525572024, 0.5251607032640779, 0.031273364291655614, 0.5651684602891733, 0.4833782607526296, 0.522481114581999, 0.53067387207457, 0.01636309315550051, 0.5760527497162795, 0.5083941678429416, 0.5308772078223155, 0.5383483269837346, 0.02398538849569036, 
0.6138641187358237, 0.5148823529890311, 0.5317355191905834, 0.5590921868458475, 0.018941050706796927, 0.6185565218733067, 0.5391848127954322, 0.5515129204797803, 0.5653692033981255, 0.022886171192539908, 0.6170498591126126, 0.5187020777516459, 0.5693268285980656, 0.5428369240411614, 0.011543007163874491, 0.5837123204211986, 0.5208221399174541, 0.5415414663324902, 0.4946660644711973, 0.021472694373470352, 0.5215764169994959, 0.4640787039752625, 0.4952267598817138, 0.4798469011394895, 0.02593484469896265, 0.5172960598832023, 0.4449712627305569, 0.4777149108114186, 0.4993938744598669, 0.01849048457494309, 0.5651910299787914, 0.4822436630327371, 0.4950261489562563, 0.5363930497563161, 0.0376443504751349, 0.6330907702118795, 0.4816294954352716, 0.5249507027509328, -235.4678661326307, 61.51638081120653, -119.29458629496251, -362.1632462796749, -227.60500825042942, 163.92070611988834, 47.05955903012367, 237.9764586528294, 41.986380826321785, 172.71493170004138, 9.237411399943188, 25.868443694231683, 61.477039729510096, -75.39528620218707, 9.629797757209056, 38.85787728431835, 25.651975918739637, 120.33667371104372, -9.003575689525233, 36.13886469019118, -3.813926397129359, 18.466559976322753, 45.395818864794386, -54.58126572108478, -3.563646356257889, 28.49882430361086, 15.286105184256387, 72.2886732962803, 0.03239718043784112, 26.491533722920998, -19.866746887564343, 16.46528562102129, 9.928420130258688, -61.42422346209003, -17.134010559191154, 4.917765483447672, 13.106589177321654, 36.30054941946764, -28.88492762419697, 4.470641784765922, -7.5214435695300805, 11.456845078656613, 24.68530842159717, -33.23468909518539, -7.800944005694487, 1.7653313822916499, 10.137823325108423, 26.38688279047729, -22.507646864346647, 2.1230603462314384, 2.9722994596741263, 9.920580299259306, 29.09083383516883, -28.462312178142557, 3.1356694281534625, -8.31659816437322, 9.321735116288234, 14.977416272339756, -29.19924207526083, -7.200232618719922, 10.020856138237773, 9.605360863583002, 
33.70453001221575, -10.34310153320585, 8.538943192527702, -0.0003117740953455404, 0.0002093530273784296, -3.649852038234921e-05, -0.0008609846033373115, -0.00024944132582088046, 2.427670449088513, 1.573081810523066, 6.574603060966783, 0.2961628052414745, 1.8991203106986, 1122.5579040699354, 895.7957759390358, 4590.354474064802, 349.53842801686966, 800.0437543350607, 1384.7323846043691, 519.4846094956321, 2642.151716668925, 703.4646482237979, 1229.7584170111122, 22.097758701059746, 6.005214057147793, 54.922406822231686, 8.895233246285754, 22.047151155860252, 6.146541272755712e-05, 0.00013457647582981735, 0.0006881643203087151, 5.692067475138174e-07, 8.736528798181098e-06, 2087.3572470319323, 1731.5818839146564, 6535.3271484375, 409.130859375, 1421.19140625, 0.05515445892467249, 0.07680443213522453, 0.46142578125, 0.0078125, 0.0302734375, 0.12412750720977785, 0.07253565639257431, 0.29952874779701233, 0.010528072714805605, 0.10663044452667236], "sample y_test": 0}
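The report is plain JSON, so you can pull out the metrics programmatically. This is a minimal sketch that assumes the file keeps the "model name", "problem type" and "metrics" keys shown above. The function name is hypothetical.

```python
import json


def summarize_report(report_path):
    """Return a short, printable summary of a model report (.json)."""
    with open(report_path) as f:
        report = json.load(f)
    lines = [f"{report['model name']} ({report['problem type']})"]
    for name, value in sorted(report.get("metrics", {}).items()):
        if isinstance(value, (int, float)):
            lines.append(f"  {name}: {value:.3f}")
        else:  # non-numeric entries are printed as-is
            lines.append(f"  {name}: {value}")
    return "\n".join(lines)
```

For the report above, `print(summarize_report("gender_tpot_regression/model/gender_tpot_regression.json"))` would print:

```
gender_tpot_regression.pickle (regression)
  mean_absolute_error: 0.370
  mean_squared_error: 0.170
  median_absolute_error: 0.411
  r2_score: 0.320
```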

The resulting model folder contains the following files:

└── gender_tpot_regression
    ├── data
    │   ├── gender_all.csv
    │   ├── gender_all_transformed.csv
    │   ├── gender_test.csv
    │   ├── gender_test_transformed.csv
    │   ├── gender_train.csv
    │   └── gender_train_transformed.csv
    ├── model
    │   ├── bar_graph_predictions.png
    │   ├── gender_tpot_regression.json
    │   ├── gender_tpot_regression.pickle
    │   ├── gender_tpot_regression.py
    │   └── gender_tpot_regression_transform.pickle
    ├── readme.md
    ├── requirements.txt
    └── settings.json
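To reuse a trained model outside Allie, you could unpickle it together with its transformer. This sketch rests on several assumptions:

- gender_tpot_regression_transform.pickle holds a fitted scikit-learn-style transformer with a transform() method. The classification log below prints one as StandardScaler followed by RFE.
- gender_tpot_regression.pickle holds an object with a predict() method.
- Your feature vector uses the same column order as training.

Unpickling may also need the same library versions used in training, such as scikit-learn==0.22.2.post1 in the log below. Only unpickle files you trust, because pickle can execute arbitrary code. The helper names are hypothetical.

```python
import pickle


def load_pickle(path):
    """Unpickle one file. Only load pickles you created or trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


def predict_one(model, transformer, features):
    """Transform one raw feature vector and return the model's prediction.

    Both objects are assumed to follow the scikit-learn convention of taking
    a 2-D input with one row per sample, hence the [features] wrapping.
    """
    transformed = transformer.transform([features])
    return model.predict(transformed)[0]
```

For example, load `gender_tpot_regression/model/gender_tpot_regression.pickle` and the matching `_transform.pickle` with `load_pickle`. Then pass a list of raw feature values to `predict_one`, such as the "sample X_test" vector in the report above.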


classification model example

To model males and females as a binary gender classification problem, type the following into the terminal. Replace /Users/jim/desktop/allie with the path to your own copy of Allie:

cd /Users/jim/desktop/allie
cd training
python3 model.py
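Before running it, you can check which classes model.py will see. In the classification transcript below, it offers ['females', 'males'] and featurizes 204 items for each. This sketch counts the files per class folder. It assumes each class is a subfolder of train_dir, which is an assumption about Allie's layout and is not stated on this page. The function name is hypothetical.

```python
from pathlib import Path


def count_class_files(train_dir):
    """Map each subfolder of train_dir (one per class, by assumption) to its file count."""
    counts = {}
    for folder in sorted(Path(train_dir).iterdir()):
        if folder.is_dir():
            counts[folder.name] = sum(1 for p in folder.iterdir() if p.is_file())
    return counts
```

Large differences between classes are worth knowing about in advance. With balance_data set to true, the "equalizing..." lines in the log suggest the larger class is trimmed toward the smaller one.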

The resulting output in the terminal will be something like:

                                                                             
                                                                             
               AAA               lllllll lllllll   iiii                      
              A:::A              l:::::l l:::::l  i::::i                     
             A:::::A             l:::::l l:::::l   iiii                      
            A:::::::A            l:::::l l:::::l                             
           A:::::::::A            l::::l  l::::l iiiiiii     eeeeeeeeeeee    
          A:::::A:::::A           l::::l  l::::l i:::::i   ee::::::::::::ee  
         A:::::A A:::::A          l::::l  l::::l  i::::i  e::::::eeeee:::::ee
        A:::::A   A:::::A         l::::l  l::::l  i::::i e::::::e     e:::::e
       A:::::A     A:::::A        l::::l  l::::l  i::::i e:::::::eeeee::::::e
      A:::::AAAAAAAAA:::::A       l::::l  l::::l  i::::i e:::::::::::::::::e 
     A:::::::::::::::::::::A      l::::l  l::::l  i::::i e::::::eeeeeeeeeee  
    A:::::AAAAAAAAAAAAA:::::A     l::::l  l::::l  i::::i e:::::::e           
   A:::::A             A:::::A   l::::::ll::::::li::::::ie::::::::e          
  A:::::A               A:::::A  l::::::ll::::::li::::::i e::::::::eeeeeeee  
 A:::::A                 A:::::A l::::::ll::::::li::::::i  ee:::::::::::::e  
AAAAAAA                   AAAAAAAlllllllllllllllliiiiiiii    eeeeeeeeeeeeee  
                                                                             
                                                                             
                                                                             
                                                                             
                                                                             
                                                                             
                                                                             

is this a classification (c) or regression (r) problem? 
c
what problem are you solving? (1-audio, 2-text, 3-image, 4-video, 5-csv)
1

 OK cool, we got you modeling audio files 

how many classes would you like to model? (2 available) 
2
these are the available classes: 
['females', 'males']
what is class #1 
males
what is class #2 
females
what is the 1-word common name for the problem you are working on? (e.g. gender for male/female classification) 
gender
-----------------------------------
          LOADING MODULES          
-----------------------------------
Requirement already satisfied: scikit-learn==0.22.2.post1 in /usr/local/lib/python3.7/site-packages (0.22.2.post1)
Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (1.18.4)
Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (1.4.1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (0.15.1)
-----------------------------------
______ _____  ___ _____ _   _______ _____ ___________ _   _ _____ 
|  ___|  ___|/ _ \_   _| | | | ___ \_   _|___  /_   _| \ | |  __ \
| |_  | |__ / /_\ \| | | | | | |_/ / | |    / /  | | |  \| | |  \/
|  _| |  __||  _  || | | | | |    /  | |   / /   | | | . ` | | __ 
| |   | |___| | | || | | |_| | |\ \ _| |_./ /____| |_| |\  | |_\ \
\_|   \____/\_| |_/\_/  \___/\_| \_|\___/\_____/\___/\_| \_/\____/
                                                                  
                                                                  
______  ___ _____ ___  
|  _  \/ _ \_   _/ _ \ 
| | | / /_\ \| |/ /_\ \
| | | |  _  || ||  _  |
| |/ /| | | || || | | |
|___/ \_| |_/\_/\_| |_/
                       
                       

-----------------------------------
-----------------------------------
           FEATURIZING MALES
-----------------------------------
males: 100%|█████████████████████████████████| 204/204 [00:00<00:00, 432.04it/s]
-----------------------------------
           FEATURIZING FEMALES
-----------------------------------
females: 100%|███████████████████████████████| 204/204 [00:00<00:00, 792.07it/s]
-----------------------------------
 _____ ______ _____  ___ _____ _____ _   _ _____ 
/  __ \| ___ \  ___|/ _ \_   _|_   _| \ | |  __ \
| /  \/| |_/ / |__ / /_\ \| |   | | |  \| | |  \/
| |    |    /|  __||  _  || |   | | | . ` | | __ 
| \__/\| |\ \| |___| | | || |  _| |_| |\  | |_\ \
 \____/\_| \_\____/\_| |_/\_/  \___/\_| \_/\____/
                                                 
                                                 
 ___________  ___  _____ _   _ _____ _   _ _____  ______  ___ _____ ___  
|_   _| ___ \/ _ \|_   _| \ | |_   _| \ | |  __ \ |  _  \/ _ \_   _/ _ \ 
  | | | |_/ / /_\ \ | | |  \| | | | |  \| | |  \/ | | | / /_\ \| |/ /_\ \
  | | |    /|  _  | | | | . ` | | | | . ` | | __  | | | |  _  || ||  _  |
  | | | |\ \| | | |_| |_| |\  |_| |_| |\  | |_\ \ | |/ /| | | || || | | |
  \_/ \_| \_\_| |_/\___/\_| \_/\___/\_| \_/\____/ |___/ \_| |_/\_/\_| |_/
                                                                         
                                                                         

-----------------------------------
-----------------------------------
			REMOVING OUTLIERS
-----------------------------------
<class 'list'>
<class 'int'>
193
11
(204, 187)
(204,)
(193, 187)
(193,)
males greater than minlength (94) by 5, equalizing...
males greater than minlength (94) by 4, equalizing...
males greater than minlength (94) by 3, equalizing...
males greater than minlength (94) by 2, equalizing...
males greater than minlength (94) by 1, equalizing...
males greater than minlength (94) by 0, equalizing...
gender_ALL.CSV
gender_TRAIN.CSV
gender_TEST.CSV
----------------------------------
 ___________  ___   _   _  ___________ ______________  ________ _   _ _____ 
|_   _| ___ \/ _ \ | \ | |/  ___|  ___|  _  | ___ \  \/  |_   _| \ | |  __ \
  | | | |_/ / /_\ \|  \| |\ `--.| |_  | | | | |_/ / .  . | | | |  \| | |  \/
  | | |    /|  _  || . ` | `--. \  _| | | | |    /| |\/| | | | | . ` | | __ 
  | | | |\ \| | | || |\  |/\__/ / |   \ \_/ / |\ \| |  | |_| |_| |\  | |_\ \
  \_/ \_| \_\_| |_/\_| \_/\____/\_|    \___/\_| \_\_|  |_/\___/\_| \_/\____/
                                                                            
                                                                            
______  ___ _____ ___  
|  _  \/ _ \_   _/ _ \ 
| | | / /_\ \| |/ /_\ \
| | | |  _  || ||  _  |
| |/ /| | | || || | | |
|___/ \_| |_/\_/\_| |_/
                       
                       

----------------------------------
Requirement already satisfied: scikit-learn==0.22.2.post1 in /usr/local/lib/python3.7/site-packages (0.22.2.post1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (0.15.1)
Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (1.4.1)
Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (1.18.4)
making transformer...
python3 transform.py audio c gender  males females
Requirement already satisfied: scikit-learn==0.22.2.post1 in /usr/local/lib/python3.7/site-packages (0.22.2.post1)
Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (1.4.1)
Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (0.15.1)
Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python3.7/site-packages (from scikit-learn==0.22.2.post1) (1.18.4)
/Users/jim/Desktop/allie
True
False
True
['standard_scaler']
['pca']
['rfe']
['males']
['males', 'females']
----------LOADING MALES----------
100%|███████████████████████████████████████| 102/102 [00:00<00:00, 3633.53it/s]
----------LOADING FEMALES----------
100%|███████████████████████████████████████| 102/102 [00:00<00:00, 3253.67it/s]
[29.0, 115.48275862068965, 61.08593194873256, 222.0, 8.0, 115.0, 135.99917763157896, 2.3062946170604843, 2.4523724313147683, 14.511204587940831, 0.0, 1.5839709225558054, 1.0, 0.0, 1.0, 1.0, 1.0, 0.726785565165924, 0.07679976746065657, 0.8995575540236156, 0.6637538362298933, 0.6773691100707808, 0.5192872926710033, 0.14462205416039903, 0.8356997271918719, 0.39185128881625986, 0.43825294226507827, 0.47576864823687653, 0.15489798102781285, 0.8172667015988133, 0.33745619224839324, 0.3880062539044845, 0.5039810096886804, 0.14210116297622033, 0.8222246027925175, 0.3891757588007075, 0.41656986998926515, 0.5373808945288242, 0.13048669284410083, 0.8278757499272521, 0.42749020156019385, 0.4603386494395968, 0.576968162842806, 0.12481555452009631, 0.8326178822159916, 0.4151756068434901, 0.5352195483012459, 0.5831458034458475, 0.12993864932034976, 0.8408276288260831, 0.4086771989157548, 0.5442120810410158, 0.5991899975086484, 0.12087077223741394, 0.8512938716671109, 0.4582549119597623, 0.548583257282575, 0.6043430612098373, 0.0967011928572277, 0.833318501846635, 0.5289767977628144, 0.5386694936622187, 0.5901015163274542, 0.09036115342542611, 0.8132908481041344, 0.5183061555498681, 0.5348989197039723, 0.5577690938261909, 0.10440904667022494, 0.8051028496830286, 0.47654504571917944, 0.48986432314249784, 0.5222194992298572, 0.12036460062379664, 0.7973796198667512, 0.4203428200207039, 0.4509165396076157, -324.9346584335899, 63.40977505468484, -239.1113566990693, -536.0134915419532, -312.1029273499763, 158.69053618273082, 35.26991926915998, 224.33444977809518, 25.32709426551922, 164.09725638069276, -41.76686834014897, 36.193666229738035, 46.89516424897592, -116.64411005852219, -48.46812935048629, 40.93271834758351, 30.292873365157128, 110.30488966414437, -34.8296992058053, 40.54577852540313, -13.68074476738829, 23.578831857611142, 44.5288579949288, -81.96856948352185, -13.20824575924119, 28.01017666759282, 19.911510776447017, 79.48989729266256, -30.98446467042396, 27.506651161152135, 
-26.204186150756332, 16.325928650509297, 17.379122853402333, -64.23824041845967, -27.70833256772887, 14.638824890367118, 14.030449436317777, 49.746826625863726, -24.54064068297873, 12.937758655225592, 1.8907564192378423, 12.717800756091705, 32.81143480306558, -34.17480821652823, 3.017798387908008, -9.890548990017422, 11.275154613335049, 16.431256434502732, -41.48821773570883, -10.025098347722025, -2.238522066589343, 11.50921922011025, 25.053143314110734, -36.57309603680529, -2.0110753582118464, -11.28338110558961, 10.092676771445209, 12.359297810934656, -45.39044308667263, -9.744274595029339, 6.634597233918086, 8.23910310866827, 31.12846300160725, -11.374600658849563, 6.9929843274651455, -9.01995244948718e-05, 4.4746520830831e-05, -9.463474948151087e-06, -0.00017975655569362465, -9.088812250730835e-05, 0.7226232345166703, 0.3516279383632571, 1.4144159675282353, 0.082382838783397, 0.7190915579267225, 1451.143003087302, 642.4123068137746, 4547.085433482971, 395.8091024437681, 1324.495426192924, 1453.4773788957211, 426.3531685791114, 2669.3033744664745, 747.1557882330682, 1339.431286902565, 17.70035758898037, 4.253697620372516, 33.22776254607448, 8.885816697236287, 17.70216728565277, 0.00011313906725263223, 0.00018414952501188964, 0.001124512287788093, 5.7439488045929465e-06, 4.929980423185043e-05, 2670.3094482421875, 1335.5639439562065, 6836.7919921875, 355.2978515625, 2282.51953125, 0.07254682268415179, 0.04210112258843493, 0.27783203125, 0.01513671875, 0.062255859375, 0.028742285445332527, 0.011032973416149616, 0.05006047338247299, 0.005435430910438299, 0.029380839318037033]
['onset_length', 'onset_detect_mean', 'onset_detect_std', 'onset_detect_maxv', 'onset_detect_minv', 'onset_detect_median', 'tempo', 'onset_strength_mean', 'onset_strength_std', 'onset_strength_maxv', 'onset_strength_minv', 'onset_strength_median', 'rhythm_0_mean', 'rhythm_0_std', 'rhythm_0_maxv', 'rhythm_0_minv', 'rhythm_0_median', 'rhythm_1_mean', 'rhythm_1_std', 'rhythm_1_maxv', 'rhythm_1_minv', 'rhythm_1_median', 'rhythm_2_mean', 'rhythm_2_std', 'rhythm_2_maxv', 'rhythm_2_minv', 'rhythm_2_median', 'rhythm_3_mean', 'rhythm_3_std', 'rhythm_3_maxv', 'rhythm_3_minv', 'rhythm_3_median', 'rhythm_4_mean', 'rhythm_4_std', 'rhythm_4_maxv', 'rhythm_4_minv', 'rhythm_4_median', 'rhythm_5_mean', 'rhythm_5_std', 'rhythm_5_maxv', 'rhythm_5_minv', 'rhythm_5_median', 'rhythm_6_mean', 'rhythm_6_std', 'rhythm_6_maxv', 'rhythm_6_minv', 'rhythm_6_median', 'rhythm_7_mean', 'rhythm_7_std', 'rhythm_7_maxv', 'rhythm_7_minv', 'rhythm_7_median', 'rhythm_8_mean', 'rhythm_8_std', 'rhythm_8_maxv', 'rhythm_8_minv', 'rhythm_8_median', 'rhythm_9_mean', 'rhythm_9_std', 'rhythm_9_maxv', 'rhythm_9_minv', 'rhythm_9_median', 'rhythm_10_mean', 'rhythm_10_std', 'rhythm_10_maxv', 'rhythm_10_minv', 'rhythm_10_median', 'rhythm_11_mean', 'rhythm_11_std', 'rhythm_11_maxv', 'rhythm_11_minv', 'rhythm_11_median', 'rhythm_12_mean', 'rhythm_12_std', 'rhythm_12_maxv', 'rhythm_12_minv', 'rhythm_12_median', 'mfcc_0_mean', 'mfcc_0_std', 'mfcc_0_maxv', 'mfcc_0_minv', 'mfcc_0_median', 'mfcc_1_mean', 'mfcc_1_std', 'mfcc_1_maxv', 'mfcc_1_minv', 'mfcc_1_median', 'mfcc_2_mean', 'mfcc_2_std', 'mfcc_2_maxv', 'mfcc_2_minv', 'mfcc_2_median', 'mfcc_3_mean', 'mfcc_3_std', 'mfcc_3_maxv', 'mfcc_3_minv', 'mfcc_3_median', 'mfcc_4_mean', 'mfcc_4_std', 'mfcc_4_maxv', 'mfcc_4_minv', 'mfcc_4_median', 'mfcc_5_mean', 'mfcc_5_std', 'mfcc_5_maxv', 'mfcc_5_minv', 'mfcc_5_median', 'mfcc_6_mean', 'mfcc_6_std', 'mfcc_6_maxv', 'mfcc_6_minv', 'mfcc_6_median', 'mfcc_7_mean', 'mfcc_7_std', 'mfcc_7_maxv', 'mfcc_7_minv', 'mfcc_7_median', 
'mfcc_8_mean', 'mfcc_8_std', 'mfcc_8_maxv', 'mfcc_8_minv', 'mfcc_8_median', 'mfcc_9_mean', 'mfcc_9_std', 'mfcc_9_maxv', 'mfcc_9_minv', 'mfcc_9_median', 'mfcc_10_mean', 'mfcc_10_std', 'mfcc_10_maxv', 'mfcc_10_minv', 'mfcc_10_median', 'mfcc_11_mean', 'mfcc_11_std', 'mfcc_11_maxv', 'mfcc_11_minv', 'mfcc_11_median', 'mfcc_12_mean', 'mfcc_12_std', 'mfcc_12_maxv', 'mfcc_12_minv', 'mfcc_12_median', 'poly_0_mean', 'poly_0_std', 'poly_0_maxv', 'poly_0_minv', 'poly_0_median', 'poly_1_mean', 'poly_1_std', 'poly_1_maxv', 'poly_1_minv', 'poly_1_median', 'spectral_centroid_mean', 'spectral_centroid_std', 'spectral_centroid_maxv', 'spectral_centroid_minv', 'spectral_centroid_median', 'spectral_bandwidth_mean', 'spectral_bandwidth_std', 'spectral_bandwidth_maxv', 'spectral_bandwidth_minv', 'spectral_bandwidth_median', 'spectral_contrast_mean', 'spectral_contrast_std', 'spectral_contrast_maxv', 'spectral_contrast_minv', 'spectral_contrast_median', 'spectral_flatness_mean', 'spectral_flatness_std', 'spectral_flatness_maxv', 'spectral_flatness_minv', 'spectral_flatness_median', 'spectral_rolloff_mean', 'spectral_rolloff_std', 'spectral_rolloff_maxv', 'spectral_rolloff_minv', 'spectral_rolloff_median', 'zero_crossings_mean', 'zero_crossings_std', 'zero_crossings_maxv', 'zero_crossings_minv', 'zero_crossings_median', 'RMSE_mean', 'RMSE_std', 'RMSE_maxv', 'RMSE_minv', 'RMSE_median']
STANDARD_SCALER
RFE - 20 features
[('standard_scaler', StandardScaler(copy=True, with_mean=True, with_std=True)), ('rfe', RFE(estimator=SVR(C=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.1,
                  gamma='scale', kernel='linear', max_iter=-1, shrinking=True,
                  tol=0.001, verbose=False),
    n_features_to_select=20, step=1, verbose=0))]
21
21
transformed training size
[ 1.40339001 -0.46826787 -1.14495583  0.16206821  0.29026766  1.12188432
  0.60680057  0.89387526  1.32137089  0.00311893  0.90967555  1.64556559
  0.29214439  0.47100284 -0.3969793   0.31780256 -0.54840539 -0.34677929
 -0.69461853 -0.60881832]
/Users/jim/Desktop/allie/preprocessing
c_gender_standard_scaler_rfe.pickle
----------------------------------
-%-$-V-|-%-$-V-|-%-$-V-|-%-$-V-|-%-$-
         TRANSFORMATION           
-%-$-V-|-%-$-V-|-%-$-V-|-%-$-V-|-%-$-
----------------------------------
[7.0, 23.428571428571427, 13.275725891987362, 40.0, 3.0, 29.0, 143.5546875, 0.9958894161283733, 0.548284548384031, 2.9561698164853145, 0.0, 0.9823586435889371, 1.0, 0.0, 1.0, 1.0, 1.0, 0.9563696862586734, 0.004556745090440225, 0.96393999788138, 0.9484953491224131, 0.956450500708107, 0.9159480905407004, 0.0083874210259623, 0.9299425651546245, 0.9015187648369452, 0.9160637212676059, 0.8944452328439548, 0.010436308983995553, 0.9118600181268972, 0.876491746933975, 0.8945884359248597, 0.8770112296385062, 0.012213153676505805, 0.897368483560061, 0.8559758923011408, 0.8771914323071244, 0.8905940594153822, 0.010861457873386392, 0.9086706269489068, 0.871855264076524, 0.8907699627836017, 0.8986618946665619, 0.010095245635220757, 0.9154303336418514, 0.8812084302562171, 0.8988437925782, 0.8940183738781617, 0.01039025341119034, 0.9112831285473728, 0.8760607413483972, 0.8942023162840362, 0.8980899592148434, 0.010095670257151154, 0.9148198685317376, 0.8805923053101389, 0.8982937440953425, 0.888882132228931, 0.01127171631616655, 0.9075145588083179, 0.8692983803153955, 0.8891346978561219, 0.8803505690032647, 0.012156972724147449, 0.9004327206206673, 0.8592151658555779, 0.8806302144175665, 0.8783421648443998, 0.012494808554290946, 0.8989303042965345, 0.8565635990020571, 0.8786581988736766, 0.8633477274070439, 0.013675039594980561, 0.8859283087034104, 0.8395630641775639, 0.8636673866109066, -313.23363896251993, 18.946764320029068, -265.4153881359699, -352.35669009191434, -314.88964335810385, 136.239475379525, 13.46457033057532, 155.28229790095634, 96.67729067600845, 138.31307847807975, -60.109940589659594, 8.90651546650511, -43.00250224341745, -77.22644879310883, -61.59614027580888, 59.959525194997426, 11.49266912690683, 79.92823038661382, 22.593262641790204, 60.14384367187341, -54.39960148805922, 12.978670142489454, -16.69391321594054, -78.18044376664089, -54.04351001558572, 30.023862498118685, 8.714431771268103, 45.984861607171624, -7.969418151448695, 30.779899533210106, 
-48.79826737490987, 9.404798307829793, -17.32746858770041, -67.85565811008664, -48.67558954047166, 16.438960903373093, 6.676108733267705, 24.75100641115554, -1.8759098025429237, 18.27300445180957, -24.239093865617573, 6.8313516276284245, -9.56759656295116, -40.92277771655667, -24.18878158134608, 3.2516761928923215, 4.2430222382933085, 10.37732827872848, -6.461490621772226, 3.393567465008272, -4.1570109920127685, 5.605424304597271, 5.78957218995748, -18.10767695295411, -3.8190369770110664, -9.46159588572396, 5.81772077466229, 2.7763746636679323, -20.054279810217025, -10.268401482915364, 9.197482271105386, 5.755721680320874, 18.46922506683798, -6.706210697210241, 10.044558505805792, -4.748126927006937e-05, 1.0575334143974938e-05, -2.4722594252240076e-05, -6.952111317908028e-05, -4.773507820227446e-05, 0.40672100602206257, 0.08467855992898438, 0.5757090803234001, 0.22579515457526012, 0.4087367660373401, 2210.951610581014, 212.91019021101542, 2791.2529330926723, 1845.3115106685345, 2223.07457835522, 2063.550185470081, 111.14828141425747, 2287.23164419421, 1816.298268701022, 2073.585928819859, 12.485818860644423, 4.014563343823625, 25.591622605877692, 7.328069768561837, 11.33830589622713, 0.0021384726278483868, 0.004675689153373241, 0.020496303215622902, 0.00027283065719529986, 0.0006564159411936998, 4814.383766867898, 483.0045387722584, 5857.03125, 3682.177734375, 4888.037109375, 0.12172629616477272, 0.0227615872259035, 0.1875, 0.0732421875, 0.1181640625, 0.011859399266541004, 0.0020985447335988283, 0.015743320807814598, 0.006857675965875387, 0.012092416174709797]
-->
[[-1.08260965 -0.98269388 -0.60797492 -0.75483856 -0.81280646 -0.89654763
  -0.2878008  -0.57018752  0.31999349  0.91470661 -0.79709927 -0.39215548
  -0.52523377  0.54936626 -0.85596512  0.88348636  0.96310551  0.00975297
   1.56752723 -0.81022666]]
----------------------------------
gender_ALL_TRANSFORMED.CSV
converting csv...: 100%|█████████████████████| 187/187 [00:00<00:00, 570.70it/s]
     transformed_feature_0  ...  class_
0                -1.109610  ...       0
1                 1.125489  ...       0
2                -1.496244  ...       0
3                -0.971811  ...       0
4                -0.961863  ...       0
..                     ...  ...     ...
182              -1.358248  ...       1
183              -0.376670  ...       1
184              -0.840383  ...       1
185              -0.662551  ...       1
186              -0.580909  ...       1

[187 rows x 21 columns]
writing csv file...
gender_TRAIN_TRANSFORMED.CSV
converting csv...: 100%|█████████████████████| 168/168 [00:00<00:00, 472.53it/s]
     transformed_feature_0  ...  class_
0                 1.378101  ...       0
1                -0.866300  ...       1
2                 1.860016  ...       1
3                -0.124242  ...       1
4                 1.015606  ...       1
..                     ...  ...     ...
163               1.151959  ...       1
164              -0.157868  ...       1
165              -1.179480  ...       0
166              -0.376670  ...       1
167              -0.580909  ...       1

[168 rows x 21 columns]
writing csv file...
gender_TEST_TRANSFORMED.CSV
converting csv...: 100%|███████████████████████| 19/19 [00:00<00:00, 419.95it/s]
    transformed_feature_0  ...  class_
0                0.194916  ...       1
1               -0.818428  ...       0
2               -0.089688  ...       1
3               -0.432771  ...       0
4               -0.457341  ...       0
5               -1.054777  ...       1
6               -0.458814  ...       0
7               -1.011486  ...       1
8                1.028686  ...       1
9                0.753718  ...       1
10              -0.959699  ...       0
11               0.515472  ...       1
12              -1.007477  ...       0
13              -0.552799  ...       1
14              -0.642334  ...       1
15               1.778480  ...       1
16              -0.444939  ...       0
17               0.939541  ...       0
18               1.460974  ...       0

[19 rows x 21 columns]
writing csv file...
----------------------------------
___  ______________ _____ _     _____ _   _ _____  ______  ___ _____ ___  
|  \/  |  _  |  _  \  ___| |   |_   _| \ | |  __ \ |  _  \/ _ \_   _/ _ \ 
| .  . | | | | | | | |__ | |     | | |  \| | |  \/ | | | / /_\ \| |/ /_\ \
| |\/| | | | | | | |  __|| |     | | | . ` | | __  | | | |  _  || ||  _  |
| |  | \ \_/ / |/ /| |___| |_____| |_| |\  | |_\ \ | |/ /| | | || || | | |
\_|  |_/\___/|___/ \____/\_____/\___/\_| \_/\____/ |___/ \_| |_/\_/\_| |_/
                                                                          
                                                                          

----------------------------------
tpot:   0%|                                               | 0/1 [00:00<?, ?it/s]
----------------------------------
       .... training TPOT           
----------------------------------
Requirement already satisfied: tpot==0.11.3 in /usr/local/lib/python3.7/site-packages (0.11.3)
Requirement already satisfied: stopit>=1.1.1 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (1.1.2)
Requirement already satisfied: pandas>=0.24.2 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (0.25.3)
Requirement already satisfied: update-checker>=0.16 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (0.17)
Requirement already satisfied: numpy>=1.16.3 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (1.18.4)
Requirement already satisfied: scipy>=1.3.1 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (1.4.1)
Requirement already satisfied: tqdm>=4.36.1 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (4.43.0)
Requirement already satisfied: deap>=1.2 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (1.3.1)
Requirement already satisfied: scikit-learn>=0.22.0 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (0.22.2.post1)
Requirement already satisfied: joblib>=0.13.2 in /usr/local/lib/python3.7/site-packages (from tpot==0.11.3) (0.15.1)
Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.7/site-packages (from pandas>=0.24.2->tpot==0.11.3) (2.8.1)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/site-packages (from pandas>=0.24.2->tpot==0.11.3) (2020.1)
Requirement already satisfied: requests>=2.3.0 in /usr/local/lib/python3.7/site-packages (from update-checker>=0.16->tpot==0.11.3) (2.24.0)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.6.1->pandas>=0.24.2->tpot==0.11.3) (1.15.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/site-packages (from requests>=2.3.0->update-checker>=0.16->tpot==0.11.3) (2020.4.5.2)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/site-packages (from requests>=2.3.0->update-checker>=0.16->tpot==0.11.3) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/site-packages (from requests>=2.3.0->update-checker>=0.16->tpot==0.11.3) (2.9)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/site-packages (from requests>=2.3.0->update-checker>=0.16->tpot==0.11.3) (1.25.9)
Warning: xgboost.XGBClassifier is not available and will not be used by TPOT.
Generation 1 - Current best internal CV score: 0.7976827094474153               
Generation 2 - Current best internal CV score: 0.8096256684491978               
Generation 3 - Current best internal CV score: 0.8096256684491978               
Generation 4 - Current best internal CV score: 0.8096256684491978               
Generation 5 - Current best internal CV score: 0.8096256684491978               
Generation 6 - Current best internal CV score: 0.8215686274509804               
Generation 7 - Current best internal CV score: 0.8215686274509804               
Generation 8 - Current best internal CV score: 0.8219251336898395               
Generation 9 - Current best internal CV score: 0.8219251336898395               
Generation 10 - Current best internal CV score: 0.8276292335115866              
tpot:   0%|                                               | 0/1 [03:58<?, ?it/s]
Best pipeline: LinearSVC(Normalizer(input_matrix, norm=max), C=20.0, dual=True, loss=hinge, penalty=l2, tol=0.0001)

/usr/local/lib/python3.7/site-packages/sklearn/svm/_base.py:947: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
  "the number of iterations.", ConvergenceWarning)
saving classifier to disk
[1 0 1 0 0 1 0 1 1 0 0 1 1 1 1 1 0 0 0]

Normalized confusion matrix
error making y_probas
error plotting ROC curve
predict_proba only works for or log loss and modified Huber loss.
tpot: 100%|██████████████████████████████████████| 1/1 [04:08<00:00, 248.61s/it]

The result is a model repository that records the model session and a summary. Accuracy metrics are computed as part of the training process, for example:

{'accuracy': 0.8947368421052632, 'balanced_accuracy': 0.8944444444444444, 'precision': 0.9, 'recall': 0.9, 'f1_score': 0.9, 'f1_micro': 0.8947368421052632, 'f1_macro': 0.8944444444444444, 'roc_auc': 0.8944444444444444, 'roc_auc_micro': 0.8944444444444444, 'roc_auc_macro': 0.8944444444444444, 'confusion_matrix': [[8, 1], [1, 9]], 'classification_report': '              precision    recall  f1-score   support\n\n       males       0.89      0.89      0.89         9\n     females       0.90      0.90      0.90        10\n\n    accuracy                           0.89        19\n   macro avg       0.89      0.89      0.89        19\nweighted avg       0.89      0.89      0.89        19\n'}


Once training finishes, the model is placed in the models/[sampletype_models] directory. For example, if you trained an audio model with TPOT, the model is placed in the allie/models/audio_models/ directory.

For automated training, you can instead pass your answers as command-line arguments (sys.argv), as follows:

python3 model.py audio 2 c gender males females

Where:

  • audio = audio file type
  • 2 = 2 classes
  • c = classification (r for regression)
  • gender = common name of model
  • males = first class
  • females = second class (list one class name per class, N names in total)
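To make the argument order concrete, here is a minimal Python sketch of how such a command line could be parsed. The function name parse_training_args and its validation rules are hypothetical illustrations, not Allie's actual model.py code, and the sketch only covers the classification form shown above.

```python
from typing import Dict, List, Union


def parse_training_args(argv: List[str]) -> Dict[str, Union[str, int, List[str]]]:
    """Parse: <sample type> <number of classes> <c|r> <common name> <class 1> ... <class N>.

    Hypothetical helper for illustration only.
    """
    if len(argv) < 4:
        raise ValueError("usage: sampletype n_classes c|r common_name class_1 ... class_N")
    sample_type, n_classes_str, mode, common_name, *classes = argv
    n_classes = int(n_classes_str)  # raises ValueError if not an integer
    if mode not in ("c", "r"):
        raise ValueError("problem type must be 'c' (classification) or 'r' (regression)")
    if len(classes) != n_classes:
        raise ValueError(f"expected {n_classes} class names, got {len(classes)}")
    return {
        "sample_type": sample_type,
        "n_classes": n_classes,
        "problem_type": "classification" if mode == "c" else "regression",
        "common_name": common_name,
        "classes": classes,
    }
```

For example, parse_training_args(["audio", "2", "c", "gender", "males", "females"]) returns the sample type, the class count as an integer, the problem type and the two class names. A mismatch between the class count and the number of class names raises ValueError before any training starts.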

Training produces an output folder like this:

└── gender_tpot_classifier
    ├── data
    │   ├── gender_all.csv
    │   ├── gender_all_transformed.csv
    │   ├── gender_test.csv
    │   ├── gender_test_transformed.csv
    │   ├── gender_train.csv
    │   └── gender_train_transformed.csv
    ├── model
    │   ├── confusion_matrix.png
    │   ├── gender_tpot_classifier.json
    │   ├── gender_tpot_classifier.pickle
    │   ├── gender_tpot_classifier.py
    │   └── gender_tpot_classifier_transform.pickle
    ├── readme.md
    ├── requirements.txt
    └── settings.json

You are now ready to load these models and make predictions.
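As a starting point, the sketch below loads the two pickles from a model folder laid out like the tree above and runs one prediction. It assumes the transform pickle exposes a scikit-learn-style transform() and the model pickle a predict(), as scikit-learn and TPOT pipelines do; the helper names are hypothetical, and Allie's own loading code may differ.

```python
import os
import pickle
from typing import Any, List, Sequence, Tuple


def load_pickle(path: str) -> Any:
    """Load one pickled object from disk (only use on files you trust)."""
    with open(path, "rb") as f:
        return pickle.load(f)


def load_model(model_root: str, model_name: str) -> Tuple[Any, Any]:
    """Load (transformer, model) from <model_root>/<model_name>/model/.

    Assumes the file names follow the gender_tpot_classifier example:
    <model_name>_transform.pickle and <model_name>.pickle.
    """
    model_dir = os.path.join(model_root, model_name, "model")
    transformer = load_pickle(os.path.join(model_dir, f"{model_name}_transform.pickle"))
    model = load_pickle(os.path.join(model_dir, f"{model_name}.pickle"))
    return transformer, model


def predict(transformer: Any, model: Any, features: Sequence[float]) -> List[Any]:
    """Transform one raw feature vector, then predict its label."""
    transformed = transformer.transform([list(features)])  # 2-D input holding one sample
    return list(model.predict(transformed))
```

Loading a pickle can execute arbitrary code, so only load models you produced yourself. For the example above, model_root would presumably be allie/models/audio_models and model_name gender_tpot_classifier, and the raw features must come from the same featurizer used during training.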

Model training scripts

Here is a quick review of all available default_training_script settings:

Setting | License | Description
--- | --- | ---
'alphapy' | Apache 2.0 | Highly customizable settings for data science pipelines/feature selection.
'atm' | MIT License | Give ATM a classification problem and a dataset as a CSV file, and ATM will build the best model it can.
'autogbt' | MIT License | An experimental Python package that reimplements AutoGBT using LightGBM and Optuna.
'autogluon' | Apache 2.0 | AutoGluon: AutoML toolkit for deep learning.
'autokaggle' | Apache 2.0 | Automated ML system trained using GBDT (regression and classification).
'autokeras' | MIT License | Automatic optimization of a neural network using neural architecture search (takes a very long time); saving and loading the resulting Keras models is consistently problematic.
'autopytorch' | Apache 2.0 | Automatic neural architecture search and hyperparameter optimization for PyTorch networks.
'btb' | MIT License | Hyperparameter tuning for various ML algorithms using Bayesian tuning and multi-armed bandits.
'cvopt' | BSD 2-Clause "Simplified" License | Machine learning parameter search / feature selection module with visualization.
'devol' | MIT License | Genetic programming to evolve Keras CNN layers.
'gama' | Apache 2.0 | An automated machine learning tool based on genetic programming.
'hungabunga' | MIT License | HungaBunga: brute-force all sklearn models with all parameters using .fit/.predict.
'hyperband' | BSD 3-Clause "New" or "Revised" License | Implements a class HyperbandSearchCV that works exactly like GridSearchCV and RandomizedSearchCV from scikit-learn, except that it runs the hyperband algorithm under the hood.
'hypsklearn' | BSD 3-Clause "New" or "Revised" License | Hyperparameter optimization on scikit-learn models.
'imbalance' | MIT License | imbalanced-learn: ML techniques for data with different numbers of samples per class.
'keras' | MIT License | Simple MLP network architecture (quick prototype; if it works, consider one of the AutoML settings).
'ludwig' | Apache 2.0 | Deep learning with Ludwig (simple configuration); converts every feature to numerical data.
'mlblocks' | MIT License | Recent framework from MIT; supports regression and classification.
'neuraxle' | Apache 2.0 | A scikit-learn-like framework for hyperparameter tuning and AutoML in deep learning projects.
'safe' | MIT License | Black-box trainer; helps reduce the opacity of ML models while increasing accuracy.
'scsr' | Apache 2.0 | Simple classification / regression (built by Jim from NLX-model).
'tpot' (default) | LGPL-3.0 | TPOT classification / regression (AutoML).

Note that you can customize default_training_script in settings.json. If you list several training scripts, e.g. ['keras','tpot'], Allie trains a model with each of them serially, one after another. A sample settings.json with the default ['tpot'] setting is shown below for reference:

{"version": "1.0.0", 
 "augment_data": false, 
 "balance_data": true, 
 "clean_data": false, 
 "create_csv": true, 
 "default_audio_augmenters": ["augment_tsaug"], 
 "default_audio_cleaners": ["clean_mono16hz"], 
 "default_audio_features": ["librosa_features"], 
 "default_audio_transcriber": ["deepspeech_dict"], 
 "default_csv_augmenters": ["augment_ctgan_regression"], 
 "default_csv_cleaners": ["clean_csv"], 
 "default_csv_features": ["csv_features"], 
 "default_csv_transcriber": ["raw text"], 
 "default_dimensionality_reducer": ["pca"], 
 "default_feature_selector": ["rfe"], 
 "default_image_augmenters": ["augment_imgaug"], 
 "default_image_cleaners": ["clean_greyscale"], 
 "default_image_features": ["image_features"], 
 "default_image_transcriber": ["tesseract"], 
 "default_outlier_detector": ["isolationforest"], 
 "default_scaler": ["standard_scaler"], 
 "default_text_augmenters": ["augment_textacy"], 
 "default_text_cleaners": ["remove_duplicates"], 
 "default_text_features": ["nltk_features"], 
 "default_text_transcriber": ["raw text"], 
 "default_training_script": ["tpot"], 
 "default_video_augmenters": ["augment_vidaug"], 
 "default_video_cleaners": ["remove_duplicates"], 
 "default_video_features": ["video_features"], 
 "default_video_transcriber": ["tesseract (averaged over frames)"], 
 "dimension_number": 2, 
 "feature_number": 20, 
 "model_compress": false, 
 "reduce_dimensions": false, 
 "remove_outliers": true, 
 "scale_features": true, 
 "select_features": true, 
 "test_size": 0.1, 
 "transcribe_audio": true, 
 "transcribe_csv": true, 
 "transcribe_image": true, 
 "transcribe_text": true, 
 "transcribe_video": true, 
 "transcribe_videos": true,
 "visualize_data": false}

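If you prefer to change the training scripts programmatically, for example in a batch job, a minimal sketch with Python's json module could look like this. The helper name is hypothetical; it rewrites only the default_training_script key and leaves every other setting untouched.

```python
import json
from typing import Any, Dict, List


def set_training_scripts(settings_path: str, scripts: List[str]) -> Dict[str, Any]:
    """Replace default_training_script in settings.json and return the updated settings."""
    with open(settings_path) as f:
        settings = json.load(f)
    # Scripts listed here are trained serially, one model session per script
    settings["default_training_script"] = list(scripts)
    with open(settings_path, "w") as f:
        json.dump(settings, f, indent=1)
    return settings
```

For instance, set_training_scripts("settings.json", ["keras", "tpot"]) would queue a Keras session followed by a TPOT session. As for test_size: with the sample value of 0.1, the 187-row gender dataset in the log above was split into 168 training and 19 test rows, which is consistent with scikit-learn's train_test_split rounding the test share up (ceil(0.1 * 187) = ceil(18.7) = 19).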
Metrics are standardized across all model training methods so that results are comparable across the various AutoML frameworks. The metrics differ between classification and regression models and are computed with the scikit-learn metrics API.

Classification

See the Classification metrics section of the scikit-learn user guide for further details.

  • accuracy: sklearn.metrics.accuracy_score
  • balanced_accuracy: sklearn.metrics.balanced_accuracy_score
  • precision: sklearn.metrics.precision_score
  • recall: sklearn.metrics.recall_score
  • f1_score: sklearn.metrics.f1_score (pos_label=1)
  • f1_micro: sklearn.metrics.f1_score(average='micro')
  • f1_macro: sklearn.metrics.f1_score(average='macro')
  • roc_auc: sklearn.metrics.roc_auc_score
  • roc_auc_micro: sklearn.metrics.roc_auc_score(average='micro')
  • roc_auc_macro: sklearn.metrics.roc_auc_score(average='macro')
  • confusion_matrix: sklearn.metrics.confusion_matrix
  • classification_report: sklearn.metrics.classification_report
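The values themselves come from scikit-learn, but for a binary problem most of them can be checked by hand from the confusion matrix. The dependency-free sketch below is a hypothetical helper, not part of Allie. It assumes scikit-learn's layout (rows for true labels, columns for predicted labels) with class 1 as the positive class. Fed the confusion matrix [[8, 1], [1, 9]] from the TPOT run above, it reproduces the reported accuracy (0.8947...), balanced_accuracy (0.8944...), and precision, recall and f1_score (0.9).

```python
from typing import Dict, List


def binary_metrics(cm: List[List[int]]) -> Dict[str, float]:
    """Binary-classification metrics from a 2x2 confusion matrix [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm

    def ratio(num: float, den: float) -> float:
        # Avoid ZeroDivisionError: return 0.0 when the denominator is empty
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)        # true positive rate
    specificity = ratio(tn, tn + fp)   # recall of class 0
    npv = ratio(tn, tn + fn)           # precision of class 0
    f1_pos = ratio(2 * precision * recall, precision + recall)
    f1_neg = ratio(2 * npv * specificity, npv + specificity)
    accuracy = ratio(tp + tn, tn + fp + fn + tp)
    return {
        "accuracy": accuracy,
        "balanced_accuracy": (recall + specificity) / 2,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_pos,
        "f1_micro": accuracy,  # micro-averaged F1 equals accuracy for single-label data
        "f1_macro": (f1_pos + f1_neg) / 2,
    }
```

When roc_auc is computed from hard 0/1 predictions rather than probabilities, the ROC curve has a single interior point and its area reduces to (TPR + TNR) / 2, i.e. balanced accuracy. This matches roc_auc equalling balanced_accuracy in both example outputs on this page.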

The output .JSON with metrics will look something like this:

{"sample type": "audio", "created date": "2020-05-18 17:10:24.747250", "session id": "6f81e898-bd03-4ba8-91e7-caf281748b83", "classes": ["males", "females"], "model type": "classification", "model name": "gender_atm.pickle", "metrics": {"accuracy": 0.8076923076923077, "balanced_accuracy": 0.8212121212121212, "precision": 0.7142857142857143, "recall": 0.9090909090909091, "f1_score": 0.8, "f1_micro": 0.8076923076923077, "f1_macro": 0.8074074074074074, "roc_auc": 0.8212121212121213, "roc_auc_micro": 0.8212121212121213, "roc_auc_macro": 0.8212121212121213, "confusion_matrix": [[22, 8], [2, 20]], "classification_report": "              precision    recall  f1-score   support\n\n       males       0.92      0.73      0.81        30\n     females       0.71      0.91      0.80        22\n\n    accuracy                           0.81        52\n   macro avg       0.82      0.82      0.81        52\nweighted avg       0.83      0.81      0.81        52\n"}, "settings": {"version": 1.0, "augment_data": false, "balance_data": true, "clean_data": false, "create_YAML": true, "default_audio_features": ["librosa_features"], "default_audio_transcriber": ["pocketsphinx"], "default_csv_features": ["csv_features"], "default_csv_transcriber": ["raw text"], "default_dimensionality_reducer": ["pca"], "default_feature_selector": ["lasso"], "default_image_features": ["image_features"], "default_image_transcriber": ["tesseract"], "default_scaler": ["standard_scaler"], "default_text_features": ["nltk_features"], "default_text_transcriber": "raw text", "default_training_script": ["atm"], "default_video_features": ["video_features"], "default_video_transcriber": ["tesseract (averaged over frames)"], "model_compress": false, "reduce_dimensions": false, "scale_features": true, "select_features": false, "test_size": 0.25, "transcribe_audio": false, "transcribe_csv": true, "transcribe_image": true, "transcribe_text": true, "transcribe_videos": true, "visualize_data": false}, "transformer name": 
"gender_atm_transform.pickle", "training data": ["gender_all.csv", "gender_train.csv", "gender_test.csv", "gender_all_transformed.csv", "gender_train_transformed.csv", "gender_test_transformed.csv"], "sample X_test": [-1.1922838749534224, -1.1929940219927937, -1.2118533388328951, -1.19399037815855, -0.0021316583752660208, -1.190766733291579, 3.849427576362443, -1.9204306379246512, -1.125006509709126, -1.545343386974132, 0.0, -1.9143559991789347, 0.0, 0.0, 0.0, 0.0, 0.0, 0.393398281153827, -1.3419713201673347, -0.31267044157215834, 0.7145414374550684, 0.4408029245727885, -0.08794721338696684, -1.2908803928098331, -0.7686039200276, 0.27355089960768003, -0.023868957327515483, -0.33493242828747644, -1.3534071112752446, -1.0663649542614346, 0.07710406397104422, -0.2620966016717233, 0.07442950075672321, -1.3586158282078342, -0.7229877787428609, 0.47998843890208265, 0.14338269561742617, 1.0289942793010851, -1.372750144952516, 0.16682659474904094, 1.3103462536843684, 1.0852524978109923, 1.3053083662553129, -1.3204536680340357, 0.44443569253872167, 1.5316554076888196, 1.3536334361962563, -0.008584151637748075, -1.3044478087345126, -0.7798374876334464, 0.4391526047447745, 0.0328217238195433, -1.6588812589246522, -1.321496997694275, -2.3058731114054853, -0.9428316620058456, -1.6054110631290472, -1.9015638712936165, -1.3181264260875947, -2.3400689045953977, -1.2257820801881496, -1.8580680903896456, -1.6875097168794329, -1.4113531659207812, -2.07877796917062, -1.0679177417254373, -1.6568549096503389, -1.7221630518945865, -1.3474956240767961, -2.135089810138456, -1.1077177226463242, -1.6750981593832228, -2.0781689670816696, -1.407674752516947, -2.466670316340356, -1.4224935683039404, -2.0228930807342236, -0.64807477922011, -0.8067816782034948, -0.6897448732292194, 0.666399646568167, -0.758504945965009, 1.686482678278769, -0.4026076385117994, 0.9932683897094231, 1.7178647246613565, 1.5792208834827586, 1.4390951796584497, -0.8982625212821831, 0.8965596316516141, 
2.1707288819224306, 1.161596616542146, 0.23397581925946248, 1.5058743230731093, -0.3256390312368899, -0.1997911700368024, 0.3913640764940668, -0.23106922325335622, 0.2688364420481832, 0.6136517467870749, 0.9853170289935439, -0.45156833200238317, -0.4039027759848379, -1.4047903060201576, -1.6199422985321121, 0.6428248152306695, -0.5021531991290126, 0.6164464311350407, -1.8395730917192974, -1.1262962931860703, 1.6725808902325139, 0.5842734033850393, 0.7649135160627001, -0.2873589631759208, -0.17297527885266198, 1.2832640412296985, 0.8199787566304052, 0.17266635859916557, -1.530303895687936, -1.035762063808967, 1.2059152644116997, 0.05560008103368553, -1.4240650565517237, 0.1317016024108744, -1.5473895633630874, -0.044686224198341895, -1.622607473245282, -0.7811768474081445, -1.3574788179509854, -1.4277185112430948, 0.8006696574528496, -0.739635158959708, -0.43596672823397514, -0.9941316877059948, -1.5527334792494474, 1.0241352281932175, -0.7322079093175251, -0.6928023654465668, 2.256128519734752, -0.6674105694489904, -0.5118001652350106, -1.119649264235433, 0.41747338869357786, -0.6721303658862441, -0.6423304904495852, 0.721819384516983, 0.3675016869618989, -0.4590847986919945, -0.6915104068484828, -0.7457704134904531, 0.5219531770125143, -0.4113120309505084, -1.722127309595865, -1.6950722496795612, -2.4638679454105468, -0.8850791263479258, -1.3195505366838527, -1.3493943022990107, -0.14490564220224492, -0.431207418652445, -0.5543954310823219, -1.1040016294903932, 1.7500526600682629, -1.0972072350519622, 0.117984869151784, 3.8430207941961694, 1.7439873570898108, -0.3527571572132666, -0.29344427565369596, -0.30383194101857836, -0.27273455031977006, -0.4812192589932978, -1.9936313049319827, -1.1758892116636155, -2.7207295590332827, -0.619693296869786, -1.753005830475332, -1.6994742556679854, -1.8507272006073032, -2.2045026156861076, -0.805498984978142, -1.3546162014501322, -0.3736833281263094, -0.7998460443253107, -0.7198574596318982, 0.9646067905470591, 
-0.3263749779375788], "sample y_test": 0}

Regression

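See the Regression metrics section of the scikit-learn user guide for further details.

  • max_error: sklearn.metrics.max_error, the maximum residual error
  • mean_absolute_error: sklearn.metrics.mean_absolute_error, mean absolute error regression loss
  • mean_squared_error: sklearn.metrics.mean_squared_error, mean squared error regression loss
  • median_absolute_error: sklearn.metrics.median_absolute_error, median absolute error regression loss
  • r2_score: sklearn.metrics.r2_score, R^2 (coefficient of determination) regression score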
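Likewise, the regression metrics reduce to simple formulas over the residuals. This plain-Python sketch (a hypothetical helper, no scikit-learn needed) mirrors the five functions listed above, assuming y_true is not constant; scikit-learn itself handles edge cases such as a constant target differently.

```python
from statistics import mean, median
from typing import Dict, Sequence


def regression_metrics(y_true: Sequence[float], y_pred: Sequence[float]) -> Dict[str, float]:
    """Plain-Python versions of the regression metrics listed above."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    abs_res = [abs(r) for r in residuals]
    y_mean = mean(y_true)
    ss_res = sum(r * r for r in residuals)             # residual sum of squares
    ss_tot = sum((t - y_mean) ** 2 for t in y_true)    # total sum of squares (non-zero assumed)
    return {
        "max_error": max(abs_res),
        "mean_absolute_error": mean(abs_res),
        "mean_squared_error": ss_res / len(residuals),
        "median_absolute_error": median(abs_res),
        "r2_score": 1 - ss_res / ss_tot,
    }
```

For example, y_true = [3.0, -0.5, 2.0, 7.0] and y_pred = [2.5, 0.0, 2.0, 8.0] give max_error 1.0, mean_absolute_error 0.5, mean_squared_error 0.375, median_absolute_error 0.5 and an r2_score of about 0.9486.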

The output .JSON will look something like this:

{"sample type": "audio", "created date": "2020-05-18 16:54:24.805218", "session id": "0d88a075-a7ab-487c-b9c2-3c7c1ae09d03", "classes": ["males", "females"], "model type": "regression", "model name": "alsdfjlsajdf_autokeras.pickle", "metrics": {"mean_absolute_error": 0.1077826815442397, "mean_squared_error": 0.07462779294115354, "median_absolute_error": 1.601874828338623e-07, "r2_score": 0.6942521937683649}, "settings": {"version": 1.0, "augment_data": false, "balance_data": true, "clean_data": false, "create_YAML": true, "default_audio_features": ["librosa_features"], "default_audio_transcriber": ["pocketsphinx"], "default_csv_features": ["csv_features"], "default_csv_transcriber": ["raw text"], "default_dimensionality_reducer": ["pca"], "default_feature_selector": ["lasso"], "default_image_features": ["image_features"], "default_image_transcriber": ["tesseract"], "default_scaler": ["standard_scaler"], "default_text_features": ["nltk_features"], "default_text_transcriber": "raw text", "default_training_script": ["autokeras"], "default_video_features": ["video_features"], "default_video_transcriber": ["tesseract (averaged over frames)"], "model_compress": false, "reduce_dimensions": false, "scale_features": true, "select_features": false, "test_size": 0.25, "transcribe_audio": false, "transcribe_csv": true, "transcribe_image": true, "transcribe_text": true, "transcribe_videos": true, "visualize_data": false}, "transformer name": "alsdfjlsajdf_autokeras_transform.pickle", "training data": ["alsdfjlsajdf_all.csv", "alsdfjlsajdf_train.csv", "alsdfjlsajdf_test.csv", "alsdfjlsajdf_all_transformed.csv", "alsdfjlsajdf_train_transformed.csv", "alsdfjlsajdf_test_transformed.csv"], "sample X_test": [0.1617091154403997, 0.0720205638907688, 0.18343183536874122, 0.1672803364150215, -0.6768015341469388, 0.017543357215669842, -0.6984277765173735, 0.38615567322347166, 0.16113189024852953, 0.4001187632663553, 0.0, 0.4873511913307515, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5387127579642551, 
0.41066147715164214, -0.19825322401743353, -0.5706310390061365, -0.4986647848248428, -0.49189625109547513, -0.17493751600011975, -0.4014143051606473, -0.3465660754407285, -0.45260975302246853, -0.45589495937639435, -0.4819237375486541, -0.5896133670337916, -0.330215482212022, -0.44043279851095235, -0.5234836122938473, -0.5272903362078499, -0.7088703332128212, -0.41531341914812164, -0.48748939273270225, -0.2618951732682184, -0.9856411619824048, -0.9400354843020243, -0.08771221511743747, -0.18482942654692902, -0.24959459599220649, -0.461204392217036, -0.72057074420697, -0.24348964155889352, -0.147406186520558, -0.9162291667305105, -0.85242473112346, -1.504207019296117, -0.580649405797514, -0.8350039253357034, -1.0312312405349793, -0.977872259346789, -1.601090597322523, -0.7066534998040032, -0.9583625521744131, -0.5346719885538214, -0.741154613709866, -0.9176221407192556, -0.4018183816122035, -0.425862924649035, -0.3009786272341064, -0.8443910549014533, -0.7325823664165042, -0.23267512513445213, -0.23968657015625028, -0.3224171248346584, -0.8056733325409708, -0.6135216108652851, -0.13110858180526833, -0.28141136962232266, -0.18871553667615804, -0.42816764163386845, -0.2448096902479657, -0.21662295115147656, -0.18750337165969908, 1.5811135800822052, 1.0748087568347513, 1.6267559896630304, 0.8394627144655394, 1.678672293821748, 0.8181658095231239, 0.4339603292131989, 0.7004815738393242, 0.17877752619514414, 0.9398282245795986, 0.5457165094370632, -0.5397498486758648, 0.8230854055311116, 0.684484079509053, 0.5435214979128995, -0.4720442259612947, 0.4558838955539394, 0.8613304491649149, -0.144243205486765, -0.6468251880219985, 0.9274784274905465, 0.9159086289445996, 1.1015505829020054, -0.30645164342333514, 0.9433508117147407, -0.005069235321852924, -0.6259481967763515, 0.27171487175342257, 0.22411722556396624, 0.02865051782674097, -0.00913579700429947, 0.7227844383541874, 0.1264126455463073, 0.05306725012946615, -0.09772130396457164, -0.6091616462319488, 
0.7263804750235214, -0.35747120634336915, -1.2508217851043137, -0.52900673597145, 0.7117290727433093, -0.4008435053731832, 0.25951657943776824, 1.0586025389442921, 0.7687242828311188, 0.10088671872271811, 0.3803402834415134, 0.21974339908525978, -0.3821105808822284, -0.027745943999639745, 0.7102387217713214, 0.6062578953010683, 0.8163296719569486, 0.5948418634657724, 0.5621170094004995, 0.4145319391371441, 0.25245737630504206, 0.23658108163017624, 0.36976632978770213, 0.4676565049792853, 0.4816786341414505, -0.05879637131605244, 0.19011971656748472, 0.7265638597457408, 0.4231243844206609, -2.859067353164431, 3.5425423206598614, -0.11121182154990267, -2.8269772874698686, -2.8586061324370533, 2.714673144897261, 3.448792952538426, 2.7455492759982465, 0.029867744349596773, 2.7262814780550926, -1.0338981213024936, -0.3891019561013453, -0.31294708873454485, -0.9377285645101899, -1.0164696231472012, -1.1425686362538043, 0.47521014698574116, -0.4525955078651535, -0.9138467786910969, -1.2621235383545972, 0.11086650741096651, -0.48742707748473124, -0.23350661020516406, 0.4736030753566866, 0.17481630054955186, -0.4646954170267768, -0.31775626788395456, -0.3101992352332374, -0.4453270996903978, -0.5292092668805664, -1.0696347050704478, 0.037624033706366376, -0.1588525625765062, -0.7051766164562955, -1.0721717209500161, -0.9504316566439719, -0.5537952472654982, -0.2727396701260328, -0.6747603035701513, -0.9216852217847525, 3.2027683754576968, 3.771258992684294, 3.109624582693844, 0.10004242630174563, 3.2109347346201025], "sample y_test": 0}
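Since the metrics .JSON format is standardized across training scripts (it is presumably the .json stored next to the pickle, e.g. model/gender_tpot_classifier.json in the folder tree above), results can be compared across frameworks with a few lines of Python. A minimal sketch, assuming the key names seen in the two examples above ("model name", "model type", "metrics", and so on); the helper itself is hypothetical.

```python
import json
from typing import Any, Dict


def summarize_model_json(path: str) -> Dict[str, Any]:
    """Return the headline fields from a model's metrics .json file."""
    with open(path) as f:
        info = json.load(f)
    metrics = info["metrics"]
    summary: Dict[str, Any] = {
        "model_name": info["model name"],  # keys in the file contain spaces
        "model_type": info["model type"],
        "sample_type": info["sample type"],
    }
    if info["model type"] == "classification":
        summary["classes"] = info["classes"]
        summary["accuracy"] = metrics["accuracy"]
        summary["confusion_matrix"] = metrics["confusion_matrix"]
    else:
        summary["r2_score"] = metrics["r2_score"]
        summary["mean_absolute_error"] = metrics["mean_absolute_error"]
    return summary
```

Running it over every .json under a models directory gives a quick side-by-side table of, say, ATM versus TPOT on the same dataset.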

Here are some settings that you can modify in the settings.json file related to Allie's Model API.

setting description default setting all options
augment_data whether or not to implement data augmentation policies during the model training process via default augmentation scripts. True True, False
balance_data whether or not to balance datasets during the model training process. True True, False
clean_data whether or not to clean datasets during the model training process via default cleaning scripts. False True, False
create_csv whether or not to output datasets in a nicely formatted .CSV as part of the model training process (outputs to ./data folder in model repositories) True True, False
default_audio_augmenters the default augmentation strategies used during audio modeling if augment_data == True ["augment_tsaug"] ["augment_tsaug", "augment_addnoise", "augment_noise", "augment_pitch", "augment_randomsplice", "augment_silence", "augment_time", "augment_volume"]
default_audio_cleaners the default cleaning strategies used during audio modeling if clean_data == True ["clean_mono16hz"] ["clean_getfirst3secs", "clean_keyword", "clean_mono16hz", "clean_towav", "clean_multispeaker", "clean_normalizevolume", "clean_opus", "clean_randomsplice", "clean_removenoise", "clean_removesilence", "clean_rename", "clean_utterances"]
default_audio_features default set of audio features used for featurization (list). ["standard_features"] ["audioset_features", "audiotext_features", "librosa_features", "meta_features", "mixed_features", "opensmile_features", "praat_features", "prosody_features", "pspeech_features", "pyaudio_features", "pyaudiolex_features", "sa_features", "sox_features", "specimage_features", "specimage2_features", "spectrogram_features", "speechmetrics_features", "standard_features"]
default_audio_transcriber the default transcription model used during audio featurization if transcribe_audio == True ["deepspeech_dict"] ["pocketsphinx", "deepspeech_nodict", "deepspeech_dict", "google", "wit", "azure", "bing", "houndify", "ibm"]
default_csv_augmenters the default augmentation strategies used to augment .CSV file types as part of model training if augment_data==True ["augment_ctgan_regression"] ["augment_ctgan_classification", "augment_ctgan_regression"]
default_csv_cleaners the default cleaning strategies used to clean .CSV file types as part of model training if clean_data==True ["clean_csv"] ["clean_csv"]
default_csv_features the default featurization technique(s) used as a part of model training for .CSV files. ["csv_features_regression"] ["csv_features_regression"]
default_csv_transcriber the default transcription technique for .CSV file spreadsheets. ["raw text"] ["raw text"]
default_dimensionality_reducer the default dimensionality reduction technique used if reduce_dimensions==True ["pca"] ["pca", "lda", "tsne", "plda","autoencoder"]
default_feature_selector the default feature selector used if select_features == True ["rfe"] ["chi", "fdr", "fpr", "fwe", "lasso", "percentile", "rfe", "univariate", "variance"]
default_image_augmenters the default augmentation techniques used for images if augment_data == True as a part of model training. ["augment_imgaug"] ["augment_imgaug"]
default_image_cleaners the default cleaning techniques used for image data as a part of model training if clean_data == True ["clean_greyscale"] ["clean_extractfaces", "clean_greyscale", "clean_jpg2png"]
default_image_features default set of image features used for featurization (list). ["image_features"] ["image_features", "inception_features", "resnet_features", "squeezenet_features", "tesseract_features", "vgg16_features", "vgg19_features", "xception_features"]
default_image_transcriber the default transcription technique used for images (e.g. image --> text transcript) ["tesseract"] ["tesseract"]
default_outlier_detector the default outlier technique(s) used to clean data as a part of model training if remove_outliers == True ["isolationforest"] ["isolationforest", "zscore"]
default_scaler the default scaling technique used to preprocess data during model training if scale_features == True ["standard_scaler"] ["binarizer", "one_hot_encoder", "normalize", "power_transformer", "poly", "quantile_transformer", "standard_scaler"]
default_text_augmenters the default augmentation strategies used during model training for text data if augment_data == True ["augment_textacy"] ["augment_textacy", "augment_summary"]
default_text_cleaners the default cleaning techniques used during model training on text data if clean_data == True ["clean_textacy"] ["clean_summary", "clean_textacy"]
default_text_features default set of text features used for featurization (list). ["nltk_features"] ["bert_features", "fast_features", "glove_features", "grammar_features", "nltk_features", "spacy_features", "text_features", "w2v_features"]
default_text_transcriber the default transcription techniques used to parse raw .TXT files during model training ["raw_text"] ["raw_text"]
default_training_script the training script(s) used to train machine learning models (list). If you specify multiple training scripts, they are executed serially. ["tpot"] ["alphapy", "atm", "autogbt", "autokaggle", "autokeras", "auto-pytorch", "btb", "cvopt", "devol", "gama", "hyperband", "hypsklearn", "hungabunga", "imbalance-learn", "keras", "ludwig", "mlblocks", "neuraxle", "safe", "scsr", "tpot"]
default_video_augmenters the default augmentation strategies used for videos during model training if augment_data == True ["augment_vidaug"] ["augment_vidaug"]
default_video_cleaners the default cleaning strategies used for videos if clean_data == True ["clean_alignfaces"] ["clean_alignfaces", "clean_videostabilize"]
default_video_features default set of video features used for featurization (list). ["video_features"] ["video_features", "y8m_features"]
default_video_transcriber the default transcription technique used for videos (.mp4 --> text from the video) ["tesseract (averaged over frames)"] ["tesseract (averaged over frames)"]
dimension_number the number of dimensions to reduce a dataset into if reduce_dimensions == True 100 any integer from 1 to the number of features-1
feature_number the number of features to select for via the feature selection strategy (default_feature_selector) if select_features == True 20 any integer from 1 to the number of features-1
model_compress a setting that specifies whether or not to compress machine learning models during model training False True, False
reduce_dimensions a setting that specifies whether or not to reduce dimensions via the default_dimensionality_reducer False True, False
remove_outliers a setting that specifies whether or not to remove outliers during model training via the default_outlier_detector True True, False