Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class AIModelsCheck : IntegrationTestRestBase() {
val metricType = "TIME_WINDOW"
val warmUpRep = 10
val maxAttemptRepair = 100 // i.e., the classifier has 10 times the chances to pick an action with non-400 response
val repairThreshold = 0.5
val weaknessThreshold = 0.6

val runIterations = 500
val saveReport = false
Expand Down Expand Up @@ -82,6 +84,9 @@ class AIModelsCheck : IntegrationTestRestBase() {
EMConfig.AIClassificationMetrics.valueOf(metricType)
config.aiResponseClassifierWarmup = warmUpRep
config.maxRepairAttemptsInResponseClassification = maxAttemptRepair
config.classificationRepairThreshold = repairThreshold
config.aIResponseClassifierWeaknessThreshold = weaknessThreshold
// config.blackBox = false
}

fun repairAction(call: RestCallAction) {
Expand Down Expand Up @@ -157,7 +162,12 @@ class AIModelsCheck : IntegrationTestRestBase() {

val metrics = aiGlobalClassifier.estimateMetrics(action.endpoint)

if (!(metrics.accuracy > 0.5 && metrics.f1Score400 > 0.5)) {
val deltaW = config.aIResponseClassifierWeaknessThreshold
if(metrics.precision400 < deltaW
|| metrics.sensitivity400 < deltaW
|| metrics.specificity < deltaW
|| metrics.npv < deltaW) {

println("The classifier is weak for $endPoint")

val result = ExtraTools.executeRestCallAction(action, baseUrlOfSut)
Expand Down Expand Up @@ -208,11 +218,11 @@ class AIModelsCheck : IntegrationTestRestBase() {

val overAllMetrics = aiGlobalClassifier.estimateOverallMetrics()

println("Overall Accuracy: ${overAllMetrics.accuracy}")
println("Overall Precision400: ${overAllMetrics.precision400}")
println("Overall Recall400: ${overAllMetrics.sensitivity400}")
println("Overall F1Score400: ${overAllMetrics.f1Score400}")
println("Overall MCC: ${overAllMetrics.mcc}")
println("Overall Accuracy: ${overAllMetrics.accuracy}")
println("Overall Precision400: ${overAllMetrics.precision400}")
println("Overall Sensitivity400: ${overAllMetrics.sensitivity400}")
println("Overall Specificity: ${overAllMetrics.specificity}")
println("Overall NPV: ${overAllMetrics.npv}")

if (saveReport) {
saveReports()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ class AIModelsCheckWFDEM : RestTestBase() {
}
}

val modelName = "KDE" // Choose "GAUSSIAN", "GLM", "KDE", "KNN", "NN", etc.
val modelName = "GAUSSIAN" // Choose "GAUSSIAN", "GLM", "KDE", "KNN", "NN", etc.
val encoderType = "RAW" // Choose "RAW" or "NORMAL"
val decisionMaking = "PROBABILITY" // Choose "PROBABILITY" or "THRESHOLD"
val decisionMaking = "THRESHOLD" // Choose "PROBABILITY" or "THRESHOLD"
val warmUpRep = 10
val maxAttemptRepair = 100 // i.e., the classifier has 10 times the chances to pick an action with non-400 response

val baseUrlOfSut = "http://localhost:8080"
// val swaggerUrl = "http://localhost:8080/v2/api-docs"
val swaggerUrl = "http://localhost:8080/api/v3/openapi.json"
// val swaggerUrl ="../WFD_Dataset/openapi-swagger/youtube-mock.yaml"
// val swaggerUrl = "http://localhost:8080/api/v3/openapi.json"
// val swaggerUrl ="../dataset/openapi-swagger/youtube-mock.yaml"
val swaggerUrl ="../dataset/openapi-swagger/catwatch.json"


fun runTest() {
Expand All @@ -33,7 +34,7 @@ class AIModelsCheckWFDEM : RestTestBase() {

// Add black-box Swagger parameters
args.add("--blackBox")
args.add("true")
args.add("false")

args.add("--ratePerMinute")
args.add("50000")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
}
}

val modelName = "KDE" // Choose "GAUSSIAN", "GLM", "KDE", "KNN", "NN", etc.
val modelName = "GAUSSIAN" // Choose "GAUSSIAN", "GLM", "KDE", "KNN", "NN", etc.
val encoderType = "RAW" // Choose "RAW" or "NORMAL"
val decisionMaking = "THRESHOLD" // Choose "PROBABILITY" or "THRESHOLD"
val warmUpRep = 10
val maxAttemptRepair = 100 // i.e., the classifier has 10 times the chances to pick an action with non-400 response
val repairThreshold = 0.5
val weaknessThreshold = 0.6

val runIterations = 1000
val saveReport = false
Expand All @@ -56,9 +58,13 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
val baseUrlOfSut = "http://localhost:8080"
// val swaggerUrl = "http://localhost:8080/v2/api-docs"
// val swaggerUrl = "http://localhost:8080/api/v3/openapi.json"
val swaggerUrl ="../WFD_Dataset/openapi-swagger/youtube-mock.yaml"
// val swaggerUrl ="../WFD_Dataset/openapi-swagger/languagetool.json"
// val swaggerUrl = "../WFD_Dataset/openapi-swagger/rest-ncs.json"

// val swaggerUrl ="../dataset/openapi-swagger/youtube-mock.yaml"
val swaggerUrl ="../dataset/openapi-swagger/catwatch.json"
// val swaggerUrl ="../dataset/openapi-swagger/blogapi.json"
// val swaggerUrl ="../dataset/openapi-swagger/languagetool.json"
// val swaggerUrl = "../dataset/openapi-swagger/rest-ncs.json"
// val swaggerUrl = "../dataset/openapi-swagger/cwa-verification.json"

@Inject
lateinit var randomness: Randomness
Expand All @@ -76,6 +82,7 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
injector = init(args)
}


fun initializeTest() {
recreateInjectorForBlackBox(listOf("--aiModelForResponseClassification", "$modelName"))
}
Expand All @@ -91,6 +98,9 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
config.aiClassifierRepairActivation = EMConfig.AIClassificationRepairActivation.valueOf(decisionMaking)
config.aiResponseClassifierWarmup = warmUpRep
config.maxRepairAttemptsInResponseClassification = maxAttemptRepair
config.classificationRepairThreshold = repairThreshold
config.aIResponseClassifierWeaknessThreshold = weaknessThreshold
config.blackBox = false
}

@Inject
Expand Down Expand Up @@ -138,6 +148,15 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
.joinToString(", ") { ng ->
"${ng.gene.name}:${ng.gene::class.simpleName ?: "Unknown"}" })

println(
"Parameter paths and encoded values are: " +
encoder.getAllParamsPathsAndEncodedValues()
.entries
.joinToString(", ") { (name, value) ->
"$name:$value"
}
)

if (encoder.areAllGenesUnSupported()) {
println("Skipping classification for $endPoint as all its genes are unsupported.")
continue
Expand All @@ -148,8 +167,13 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
println("Input vector size: ${inputVector.size}")

// Warm-up
val innerModel = aiGlobalClassifier.viewInnerModels()
println("innerModel is ${innerModel.javaClass.simpleName ?: "Unknown"}")
val innerModels = aiGlobalClassifier.viewInnerModels()

val innerModel = innerModels.firstOrNull()
?: throw IllegalStateException("No inner models found")

println("innerModel is ${innerModel.javaClass.simpleName}")

val endpointModel = when(innerModel) {
is Gaussian400Classifier -> innerModel.getModel(endPoint)
is GLM400Classifier -> innerModel.getModel(endPoint)
Expand All @@ -166,7 +190,11 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
val metrics = aiGlobalClassifier.estimateMetrics(action.endpoint)

//Execute the action if the classifier is still weak
if(!(metrics.accuracy > 0.5 && metrics.f1Score400 > 0.0 && metrics.mcc > 0.0)){
val deltaW = config.aIResponseClassifierWeaknessThreshold
if(metrics.precision400 < deltaW
|| metrics.sensitivity400 < deltaW
|| metrics.specificity < deltaW
|| metrics.npv < deltaW) {

println("The classifier is weak for $endPoint")
val result = ExtraTools.executeRestCallAction(action, "$baseUrlOfSut")
Expand Down Expand Up @@ -244,11 +272,11 @@ class AIModelsCheckWFD : IntegrationTestRestBase() {
}

val overAllMetrics = aiGlobalClassifier.estimateOverallMetrics()
println("Overall Accuracy: ${overAllMetrics.accuracy}")
println("Overall Precision400: ${overAllMetrics.precision400}")
println("Overall Recall400: ${overAllMetrics.sensitivity400}")
println("Overall F1Score400: ${overAllMetrics.f1Score400}")
println("Overall MCC: ${overAllMetrics.mcc}")
println("Overall Accuracy: ${overAllMetrics.accuracy}")
println("Overall Precision400: ${overAllMetrics.precision400}")
println("Overall Sensitivity400: ${overAllMetrics.sensitivity400}")
println("Overall Specificity: ${overAllMetrics.specificity}")
println("Overall NPV: ${overAllMetrics.npv}")

// Save the final result as a .txt file
if (saveReport){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ object ExtraTools {
| Actual 400 | TP=${metrics.truePositive400.toString().padEnd(10)}| FN=${metrics.falseNegative400.toString().padEnd(11)}|
| Actual¬400 | FP=${metrics.falsePositive400.toString().padEnd(10)}| TN=${metrics.trueNegative400.toString().padEnd(11)}|
+-------------------------------------------+
Window Total : ${metrics.windowTotal}
Accuracy : ${"%.4f".format(metrics.estimateMetrics().accuracy)}
Precision400 : ${"%.4f".format(metrics.estimateMetrics().precision400)}
Recall400 : ${"%.4f".format(metrics.estimateMetrics().sensitivity400)}
F1Score400 : ${"%.4f".format(metrics.estimateMetrics().f1Score400)}
MCC400 : ${"%.4f".format(metrics.estimateMetrics().mcc)}
Window Total : ${metrics.windowTotal}
Accuracy : ${"%.4f".format(metrics.estimateMetrics().accuracy)}
Precision400 : ${"%.4f".format(metrics.estimateMetrics().precision400)}
Sensitivity400 : ${"%.4f".format(metrics.estimateMetrics().sensitivity400)}
Specificity : ${"%.4f".format(metrics.estimateMetrics().specificity)}
NPV : ${"%.4f".format(metrics.estimateMetrics().npv)}
""".trimIndent()
)
}
Expand Down Expand Up @@ -107,12 +107,12 @@ object ExtraTools {
sb.appendLine("| Actual¬400 | FP=${it.falsePositive400.toString().padEnd(10)}| TN=${it.trueNegative400.toString().padEnd(11)}|")
sb.appendLine("+-------------------------------------------+")
sb.appendLine()
sb.appendLine("Window Total : ${it.windowTotal}")
sb.appendLine("Accuracy : ${"%.4f".format(it.estimateMetrics().accuracy)}")
sb.appendLine("Precision400 : ${"%.4f".format(it.estimateMetrics().precision400)}")
sb.appendLine("Recall400 : ${"%.4f".format(it.estimateMetrics().sensitivity400)}")
sb.appendLine("F1Score400 : ${"%.4f".format(it.estimateMetrics().f1Score400)}")
sb.appendLine("MCC400 : ${"%.4f".format(it.estimateMetrics().mcc)}")
sb.appendLine("Window Total : ${it.windowTotal}")
sb.appendLine("Accuracy : ${"%.4f".format(it.estimateMetrics().accuracy)}")
sb.appendLine("Precision400 : ${"%.4f".format(it.estimateMetrics().precision400)}")
sb.appendLine("Sensitivity400 : ${"%.4f".format(it.estimateMetrics().sensitivity400)}")
sb.appendLine("Specificity : ${"%.4f".format(it.estimateMetrics().specificity)}")
sb.appendLine("NPV : ${"%.4f".format(it.estimateMetrics().npv)}")
sb.appendLine()
sb.appendLine("=============================================")
sb.appendLine()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ abstract class AbstractProbabilistic400Classifier<T : AIModel>(
return
}

// create an endPoint model if it does not exist
val m = models.getOrPut(endpoint) {
val encoder = InputEncoderUtilWrapper(input, encoderType = encoderType)

Expand All @@ -46,20 +47,26 @@ abstract class AbstractProbabilistic400Classifier<T : AIModel>(
return@getOrPut null
}

val listGenes = encoder.endPointToGeneList().map { it.gene.getLeafGene() }
val initialParamPaths = encoder.getAllParamsPathsAndEncodedValues().keys.toList()

createEndpointModel(
endpoint, warmup,
listGenes.size,
endpoint,
warmup,
initialParamPaths,
initialParamPaths.size,
encoderType,
metricType,
randomness)
randomness
)
}


if (m == null) {
unsupportedEndpoints.add(endpoint)
return
}

// update the endpoint model and initialize if needed
m.updateModel(input, output)
}

Expand Down Expand Up @@ -99,6 +106,7 @@ abstract class AbstractProbabilistic400Classifier<T : AIModel>(
protected abstract fun createEndpointModel(
endpoint: Endpoint,
warmup: Int,
modelKeys: List<String>,
dimension: Int,
encoderType: EMConfig.EncoderType,
metricType: EMConfig.AIClassificationMetrics,
Expand Down
Loading