A Voyage AI se une ao MongoDB para impulsionar aplicativos de AI mais precisos e confiáveis no Atlas.

Explore o novo chatbot do Developer Center! O MongoDB AI chatbot pode ser acessado na parte superior da sua navegação para responder a todas as suas perguntas sobre o MongoDB .

Desenvolvedor do MongoDB
Centro de desenvolvedores do MongoDB
chevron-right
Idiomas
chevron-right
Java
chevron-right

Como construir um modelo de detecção de fraudes em Java usando o aprendizado profundo4J

Tim Kelly15 min read • Published Jan 27, 2025 • Updated Jan 27, 2025
IAJava
APLICATIVO COMPLETO
Ícone do FacebookÍcone do Twitterícone do linkedin
Avalie esse Tutorial
star-empty
star-empty
star-empty
star-empty
star-empty
Neste tutorial, aprenderemos como combinar o poder das redes generativas e o gerenciamento de dados escalável para criar um sistema de detecção de fraudes do mundo real. A IA não é mais apenas para JavaScript e Python. Com a ajuda do deep learning ,4J, treinaremos uma rede causal no Java e, com o MongoDB, gerenciaremos e armazenaremos os dados da transação de forma eficiente.
Com toda uma série de dados sobre transações de cliente , vamos aprender ou modelar como identificar as fraudes. Quer se trate de uma quantidade suspeita, local estranho ou horário singular, muitas variáveis estão em jogo e é importante fazer isso corretamente! Pré-processaremos nossos dados de transação e treinaremos uma rede causal para integrar o MongoDB para armazenamento de dados escalável. No final, teremos um aplicação totalmente funcional capaz de identificar transações enganosas.
Se você quiser clonar este repositório ou apenas visualizar todo o código, confira o repositório do Github.

O que construiremos

Criaremos um sistema que:
  1. Carrega e pré-processa dados de transação de um arquivo CSV.
  2. Treina uma rede causal para classificar transações como maliciosos ou não maliciosos.
  3. Armazena e recupera dados de transação no MongoDB, garantindo escalabilidade e persistência.

Por que Java para IA?

Java muitas vezes passa desfocado no mundo da IA, mas Java tem alguns pontos fortes inegáveis que o tornam uma escolha sólida para construir sistemas de IA, especialmente quando precisamos ir além dos experimentos e entrar em nossos ambientes de produção escaláveis.
  • Integração com sistemas empresariais: Java é uma base no software empresarial. Quando queremos incorporar IA em sistemas existentes, a predominância do Java o torna natural.
  • Desempenho e escalabilidade: com suas otimizações multithreading e JVM, o Java lida com aplicativos distribuídos e de alto desempenho com eleg
Ao usar Java para IA, não estamos apenas construindo algo legal — estamos criando algo robusto, escalável e pronto para lidar com as cobranças do mundo real da produção. Já chega de tagarelando sobre por que isso é legal, vamos começar a construir algo.

O que abordaremos

  1. Pré-processamento de dados: como limpar e preparar dados de transação para treinamento
  2. Treinando a rede causal: Usando4o aprendizado a fundo J para treinar um modelo de classificação
  3. Integração ao MongoDB: Gerenciando dados de transações em um banco de dados escalável e eficiente
  4. Interação em tempo real: construindo uma CLI interativa para testar nosso modelo e executar predições. Isso pode ser facilmente uma API ou qualquer outra maneira de você planejar interagir com seu aplicação.

Pré-requisitos

Para acompanhar este tutorial, você deve ter o seguinte:

O que é o deep learning4J?

Kerry4J é uma biblioteca de aprendizagem profunda baseada em Java projetada para desenvolvedores que desejam criar sistemas de IA prontos para produção sem sair do ecossistema JVM . Ao contrário de algumas das ferramentas mais experimentais no mundo da IA, oDeopearing4J é muito maismaduro e se concentra em casos de uso do mundo real, oferecendo recursos que se integram a aplicativos empresariais.
Correndo o risco de soar como uma ação publicitária para isso, veja o que a destaca:
  • Suporte nativo a Java: foi criado para desenvolvedores Java, para que possamos usar ferramentas e fluxos de trabalho conhecidos para criar modelos de aprendizado profundo (como o MongoDB).
  • Escalabilidade: o deep learning4J suporta treinamento distribuído imediatamente, o que o torna ideal para grandes conjuntos de dados e aplicativos de alto desempenho.
  • Flexibilidade: estejamos trabalhando em classificação, regressão ou arquiteturas mais complexas, a biblioteca fornece os blocos de construção para personalizar e otimizar nossos modelos.
  • Integração: adestramento detalhado4J funciona bem com outras ferramentas Java, incluindo o Apache Spark para processamento de big data e o Hadoop para sistemas distribuídos.
O deep learning4j se destaca por ser mais do que apenas experimentar IA em Java— estamos construindo modelos que podem ser implantados diretamente em ambientes de produção, junto com o restante de nossos aplicativos Java .

Entendendo os dados

Explorando nosso conjunto de dados de fraudes com cartão de crédito

Esse conjunto de dados tem as seguintes características principais:
  1. Tamanho e composição:
    • Total de transações: 284,807
    • Transações enganosas: 492 (0.172% dos dados)
    • O conjunto de dados é altamente desequilibrado, pois as transações enganosas representam uma porcentagem muito pequena do total. Isso reflete o desafio do mundo real, em que a fraude é rara, mas crítica de detectar.
  2. Recursos:
    • Tempo: os segundos decorridos entre uma transação e a primeira transação no conjunto de dados.
    • Valor: o valor monetário da transação.
    • V1 a V28: componentes principais resultantes da PCA (Análise de Componentes Principais), usados para tornar dados confidenciais anônimos.
    • Classe: a variável de destino, em que 0 indica uma transação válida e 1 indica fraude.

Reconhecendo o desequilíbrio de classe

O desequilíbrio extremo de classe (transações enganosas representam apenas 0.172% do conjunto de dados) pode complicar o aprendizado de máquina:
  • Os modelos podem se tornar enviesados para a classe majoritária (0), levando a um desempenho ruim na classe majoritária (1).
Para este tutorial, não vamos nos concentrar em abordar o desequilíbrio. Em vez disso, treinaremos e avaliaremos o modelo no conjunto de dados original, tratando-o como um exercício de aprendizado. No entanto, técnicas como reamostragem, funções de perda ponderada ou aprendizado sensível ao custo seriam essenciais para lidar com um conjunto de dados desequilibrado.

Preparando os dados

Para preparar os dados para o nosso modelo, seguiremos estas etapas:
  1. Selecionando funcionalidades:
    • Começaremos com o recurso Valor como nossa entrada.
    • O recurso de tempo e os componentes de PCA (V1 a V28) serão excluídos por simplicidade, mas podem ser explorados posteriormente para aprimorar o modelo. Realizei um teste com eles mais tarde, e você poderá ver por que decidimos excluí-los para nossa implementação simples.
  2. Variável de destino:
    • A coluna Classificação serve como nossa variável de destino, com 0 para transações válidas e 1 para fraudes.
  3. Normalizando dados:
    • As redes mentais têm melhor desempenho com dimensionamento de entrada consistente, portanto, normalizaremos a coluna Valor para uma faixa de 0 a 1.
  4. Dividindo os dados:
    • Dividiremos o conjunto de dados em um conjunto de treinamento (80% ) e um conjunto de teste (20% ) para avaliar se nosso modelo aplica bem os dados não vistos.
Ao manter o pré-processamento simples, podemos nos concentrar na compreensão da mecânica da construção e do treinamento de uma rede causal. Embora não estejamos abordando o desequilíbrio de classe diretamente neste tutorial, ele continua sendo uma consideração absolutamente crítica para aplicativos do mundo real. Retornarei este ponto algumas vezes ao longo deste tutorial, mas isso é porque ele é provavelmente o fator que mais influencia nossos resultados neste modelo. Vamos passar para o código de pré-processamento real!

Nossas dependências

Primeiro, precisamos adicionar nossas dependências ao nosso POM.
1<dependencies>
2 <dependency>
3 <groupId>org.deeplearning4j</groupId>
4 <artifactId>deeplearning4j-core</artifactId>
5 <version>1.0.0-M2.1</version>
6 </dependency>
7 <dependency>
8 <groupId>org.nd4j</groupId>
9 <artifactId>nd4j-native-platform</artifactId>
10 <version>1.0.0-M2.1</version>
11 </dependency>
12 <dependency>
13 <groupId>org.mongodb</groupId>
14 <artifactId>mongodb-driver-sync</artifactId>
15 <version>5.2.0</version>
16 </dependency>
17 <dependency>
18 <groupId>org.slf4j</groupId>
19 <artifactId>slf4j-simple</artifactId>
20 <version>2.0.16</version>
21 </dependency>
22
23</dependencies>
Aqui, estamos importando algumas dependências para oDeepearing4j, bem como o driver Java do MongoDB . Também temos slf4j apenas para registro e nd4j. Nd4j fornece métodos de conveniência para a criação de arrays a partir de arrays Java flutuantes e duplas.

Configurando nossa conexão MongoDB

Crie uma MongoDBConnector classe . será assim que estabelecemos nossa conexão com nosso banco de dados MongoDB , onde armazenaremos nossos dados.
Para URI o, adicione sua string de conexão para seu banco de dados. Altere o nome do banco de dados e da coleção para o que fizer sentido, mas estou usando fraudDection e transactions, respectivamente, para este exemplo.
1package com.mongodb;
2
3import com.mongodb.client.MongoClient;
4import com.mongodb.client.MongoClients;
5import com.mongodb.client.MongoCollection;
6import com.mongodb.client.MongoDatabase;
7import org.bson.Document;
8
9import java.util.concurrent.TimeUnit;
10
11public class MongoDBConnector {
12 private static final String URI = "YOUR-CONNECTION-STRING";
13 private static final String DATABASE_NAME = "fraudDetection";
14 private static final String COLLECTION_NAME = "transactions";
15
16 private final MongoClient mongoClient;
17
18 public MongoDBConnector() {
19 MongoClientSettings settings = MongoClientSettings.builder()
20 .applyConnectionString(new ConnectionString(URI))
21 .applyToSocketSettings(builder ->
22 builder.connectTimeout(30, TimeUnit.SECONDS)
23 .readTimeout(30, TimeUnit.SECONDS))
24 .build();
25
26 mongoClient = MongoClients.create(settings);
27 }
28
29 public MongoCollection<Document> getCollection() {
30 return mongoClient.getDatabase(DATABASE_NAME).getCollection(COLLECTION_NAME);
31 }
32
33}
Também estamos aplicando configurações específicas de tempo limite ao nosso MongoClient. Como estamos carregando uma grande quantidade de dados para nosso banco de dados em grandes blocos, isso nos ajudará a trabalhar em torno de quaisquer exceções de tempo limite causadas por problemas de rede.

Criando a transação POJO

Para trabalhar com nossos dados de transação, criaremos uma Transaction classe . Vamos simplificá-lo para esta demonstração e apenas usar o valor e se ele foi marcado como malicioso ou não.
Isso afetará significativamente a confiabilidade do nosso modelo, portanto, fique à vontade para adicionar mais recursos para classificar como desejar. A IA é um mundo complexo, e modelos de treinamento são uma habilidade real. Funcionalidades diferentes serão mais confiáveis para fazer projeções enganosas. Pense em transações grandes no lado mais distante do mundo ou durante horários ímpares para esse usuário específico.
1package com.mongodb;
2
3import org.bson.Document;
4
5public class Transaction {
6 private double amount;
7 private boolean isFraudulent;
8
9 public Transaction(double amount, boolean isFraudulent) {
10 this.amount = amount;
11 this.isFraudulent = isFraudulent;
12 }
13
14 public double getAmount() { return amount; }
15 public boolean isFraudulent() { return isFraudulent; }
16
17 public Document toDocument() {
18 Document doc = new Document();
19 doc.append("amount", amount);
20 doc.append("isFraudulent", isFraudulent);
21 return doc;
22 }
23}
Isso fornecerá uma estrutura simples para nossos dados e nos permitirá convertê-los em um formato compatível com o MongoDB com o toDocument método.

Armazenando e salvando nossos dados

Em seguida, precisamos de uma TransactionRepository classe . É aqui que encapsulamos as operações para salvar e buscar nossas transações do banco de dados. Essa abstração nos permite manter nossa lógica de acesso a dados organizada e reutilizável.
1package com.mongodb;
2
3import com.mongodb.client.MongoCollection;
4import com.mongodb.client.model.BulkWriteOptions;
5import com.mongodb.client.model.WriteModel;
6import org.bson.Document;
7
8import java.util.ArrayList;
9import java.util.List;
10
11public class TransactionRepository {
12 private final MongoCollection<Document> collection;
13
14 public TransactionRepository(MongoDBConnector connector) {
15 this.collection = connector.getCollection();
16 }
17
18 public void bulkSaveTransactions(List<WriteModel<Document>> transactions) {
19 if (!transactions.isEmpty()) {
20 try {
21 BulkWriteOptions options = new BulkWriteOptions().ordered(true);
22 collection.bulkWrite(transactions, options); // Perform bulk write operation
23 } catch (Exception e) {
24 e.printStackTrace();
25 }
26 }
27 }
28
29 public List<Transaction> getAllTransactions() {
30 List<Transaction> transactions = new ArrayList<>();
31 for (Document doc : collection.find()) {
32 double amount = doc.getDouble("amount");
33 boolean isFraudulent = doc.getBoolean("isFraudulent");
34 transactions.add(new Transaction(amount, isFraudulent));
35 }
36 return transactions;
37 }
38
39}
Esse padrão de repositório facilita a adição de mais operações de banco de dados , como filtrar nossas transações ou executar queries mais avançadas.

Pré-processamento de dados em Java

O pré-processamento eficaz dos dados é uma etapa essencial no aprendizado de máquina. Aqui, vamos carregar os dados do creditcard.csv arquivo, que devem ser colocados no src/main/resources diretório. Leia os campos nos quais queremos focar em nosso treinamento de modelo e formate-o para nosso modelo de transação.
1package com.mongodb;
2
3import com.mongodb.client.model.WriteModel;
4import com.mongodb.client.model.InsertOneModel;
5import org.bson.Document;
6
7import java.io.BufferedReader;
8import java.io.FileReader;
9import java.io.IOException;
10import java.util.ArrayList;
11import java.util.List;
12import java.util.stream.Collectors;
13
14public class DataPreprocessor {
15
16 TransactionRepository transactionRepository;
17
18 private static final int BATCH_SIZE = 500; // Define batch size for bulk writes
19 private static final int DOCUMENT_LIMIT = 250000; // Maximum documents to insert
20
21 private int documentCount = 0; // Counter for total inserted documents
22
23 public DataPreprocessor(TransactionRepository transactionRepository) {
24 this.transactionRepository = transactionRepository;
25 }
26
27 public void loadData(String filePath) throws IOException {
28 try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
29 // Skip the header by reading the first line
30 reader.readLine();
31
32 List<String> batch = new ArrayList<>();
33 String line;
34
35 // Read the file line-by-line
36 while ((line = reader.readLine()) != null) {
37 if (documentCount >= DOCUMENT_LIMIT) {
38 System.out.println("Reached the document limit of " + DOCUMENT_LIMIT + ". Stopping data load.");
39 break; // Stop processing when the limit is reached
40 }
41
42 batch.add(line);
43 documentCount++;
44
45 // When batch size is reached, process it
46 if (batch.size() == BATCH_SIZE) {
47 processBatch(batch);
48 batch.clear(); // Clear the batch for the next set of lines
49 }
50 }
51
52 // Process any remaining lines
53 if (!batch.isEmpty()) {
54 processBatch(batch);
55 }
56 }
57 }
58
59 private void processBatch(List<String> batch) {
60 List<WriteModel<Document>> bulkOperations = batch.stream()
61 .map(line -> {
62 String[] fields = line.split(",");
63 double amount = Double.parseDouble(fields[29]); // Adjust index as needed
64 boolean isFraudulent = "1".equals(fields[30]);
65 Transaction transaction = new Transaction(amount, isFraudulent);
66 return new InsertOneModel<>(transaction.toDocument());
67 })
68 .collect(Collectors.toList());
69 transactionRepository.bulkSaveTransactions(bulkOperations);
70 }
71
72}
Também estamos limitando o número de documentos adicionados ao banco de dados para 250,000. Isso ocorre porque estamos usando o cluster de camada MongoDB M0. É limitado a 500MB e não queremos exceder isso.

Construindo uma rede Neural com o Deapearing4J

Para detectar as transações enganosas, usaremos o deeplearing4J para construir nossa rede causal. Faremos isso em uma FraudDetectionModel classe .

Configurando a rede causal

Vamos configurar uma rede feedforward simples para especificar se uma transação é enganosa ou não.
  • Camada de entrada (implicita):
    cada rede generativa começa com uma camada de entrada que recebe os dados. No deep learning4J, isso é tratado automaticamente pela primeira camada. O nIn(NUM_INPUT_FEATURES) no primeiro DenseLayer especifica quantas funcionalidades de entrada o modelo aceitará. Por enquanto, estamos usando apenas o recurso Amount, mas você pode expandi-lo facilmente.
  • Camada compacta (camada oculta):
    este é o núcleo do nosso modelo onde o aprendizado acontece. Ela consiste em 10 conexões e usa a função de ativação ReLU (Unidade Linear Retificada), que introduz não linearidade para ajudar o modelo a aprender padrões complexos nos dados.
  • Camada de saída:
    a camada de saída prever se uma transação é enganosa ou não. Ela tem dois nós — um para cada classe (fraudulenta ou legitime) — e usa a função de ativação do Sofmax para gerar probabilidades para cada classe. A classe com a maior probabilidade se torna a predição do modelo.
  • Função de perda e otimizador:
    • O modelo usa a função de perda de probabilidade de registro negativo, ideal para problemas de classificação.
    • Ele é otimizado usando o otimizadorAdam com uma taxa de aprendizado de 0.001, que adapta a taxa de aprendizado ao longo do treinamento para melhor convergência.
    • A inicialização do peso Xavier é usada para manter os pesos iniciais equilibrados, evitando problemas como gradientes que desaparecem ou explodem.
  • Feedback do treinamento:
    Adicionamos um ScoreIterationListener que gera o progresso de treinamento do modelo a cada 10 iterações, nos fornecendo insights sobre quão bem o modelo está aprendendo.
Se nada do que eu terminei de fazer faz sentido, não se desespere, pois eu também não fazia. Mostrei como implementá-la abaixo, para que você possa vê-la em ação. A IA é um campo grande, e muitas pessoas inteligentes fizeram um trabalho surpreendente para torná-lo mais acessível. Confira o Centro de Desenvolvedores do MongoDB para saber mais sobre IA.
1package com.mongodb;
2
3import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
4import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
5import org.deeplearning4j.nn.conf.layers.DenseLayer;
6import org.deeplearning4j.nn.conf.layers.OutputLayer;
7import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
8import org.deeplearning4j.nn.weights.WeightInit;
9import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
10import org.nd4j.evaluation.classification.Evaluation;
11import org.nd4j.linalg.activations.Activation;
12import org.nd4j.linalg.api.ndarray.INDArray;
13import org.nd4j.linalg.dataset.DataSet;
14import org.nd4j.linalg.factory.Nd4j;
15import org.nd4j.linalg.learning.config.Adam;
16import org.nd4j.linalg.lossfunctions.LossFunctions;
17
18import java.util.Collections;
19import java.util.List;
20
21public class FraudDetectionModel {
22 private MultiLayerNetwork model;
23
24 private static final int NUM_INPUT_FEATURES = 1;
25 private static final int NUM_CLASSES = 2;
26
27 public FraudDetectionModel() {
28 initializeModel();
29 }
30
31 private void initializeModel() {
32 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
33 .seed(123)
34 .weightInit(WeightInit.XAVIER)
35 .updater(new Adam(0.001))
36 .list()
37 .layer(0, new DenseLayer.Builder()
38 .nIn(NUM_INPUT_FEATURES)
39 .nOut(10)
40 .activation(Activation.RELU)
41 .build())
42 .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
43 .nIn(10)
44 .nOut(NUM_CLASSES)
45 .activation(Activation.SOFTMAX)
46 .build())
47 .build();
48
49 model = new MultiLayerNetwork(conf);
50 model.init();
51 model.setListeners(new ScoreIterationListener(10));
52 }
53
54}
Então, o que média tudo isso? Bem, vamos explorar os componentes dessa classe e o que cada parte faz.

Variáveis de classe

1private MultiLayerNetwork model;
2private static final int NUM_INPUT_FEATURES = 1;
3private static final int NUM_CLASSES = 2;
model: A instância da rede causal que será habilitada
  • NUM_INPUT_FEATURES: o número de funcionalidades de entrada — atualmente definido como 1 porque somente o valor da transação é usado
  • NUM_CLASSES: definido como 2, representando dois resultados possíveis: legítimo (0) ou malicioso (1)

Inicialização do modelo

1private void initializeModel() {
2 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
3 .seed(123)
4 .weightInit(WeightInit.XAVIER)
5 .updater(new Adam(0.001))
6 .list()
  • NeuralNetConfiguration.Builder(): usado para definir a arquitetura de rede.
  • seed(123): define uma semente aleatória para reprodutibilidade.
  • weightInit(WeightInit.XAVIER): Inicializa os pesos da rede usando a inicialização Apache, que equilibra a variação dos pesos, ajudam a rede a treinar com eficiência.
  • updater(new Adam(0.001)): usa o otimizadorAdam com uma taxa de aprendizado de 0.001 para aprendizado adaptável

Camada oculta (camada compacta)

1.layer(0, new DenseLayer.Builder()
2 .nIn(NUM_INPUT_FEATURES)
3 .nOut(10)
4 .activation(Activation.RELU)
5 .build())
  • Camada 0: a primeira camada oculta
  • nIn(NUM_INPUT_FEATURES): número de entradas, atualmente 1 (valor da transação)
  • nOut(10): Número de conexões nesta camada (10 conexões)
  • Activation.RELU: A ativação ReLU introduz não linearidade, permitindo que o modelo aprenda relacionamentos complexos

Camada de saída

1.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
2 .nIn(10)
3 .nOut(NUM_CLASSES)
4 .activation(Activation.SOFTMAX)
5 .build())
  • Camada 1: a camada de saída
  • LossFunction.NEGATIVELOGLIKELIHOOD: Uma função de perda adequada para tarefas de classificação
  • nIn(10): Tira as saídas 10 da camada oculta
  • nOut(NUM_CLASSES): Produz dois valores representando as probabilidades de cada classe (fraucesso ou legítimo)
  • Activation.SOFTMAX: Converte saídas em probabilidades de classe que somam até 1.

Inicialização do modelo e feedback de treinamento

1model = new MultiLayerNetwork(conf);
2model.init();
3model.setListeners(new ScoreIterationListener(10));
  • model = new MultiLayerNetwork(conf);: Cria a rede causal usando a configuração
  • model.init();: Inicializa os parâmetros da rede
  • setListeners(new ScoreIterationListener(10));: registra a pontuação do modelo (erro) a cada 10 iterações para acompanhar o progresso do treinamento
Em seguida, adicionaremos o método prepareTrainingData à nossa classe. A rede espera que os dados estejam na forma de feições (entradas) e rótulos (destinos).
1 private DataSet prepareTrainingData(List<Transaction> transactions) {
2 int numTransactions = transactions.size();
3 INDArray features = Nd4j.create(numTransactions, NUM_INPUT_FEATURES);
4 INDArray labels = Nd4j.create(numTransactions, NUM_CLASSES);
5 DataSet dataSet;
6
7 for (int i = 0; i < numTransactions; i++) {
8 Transaction transaction = transactions.get(i);
9
10 // Use transaction amount as feature
11 features.putScalar(new int[]{i, 0}, transaction.getAmount());
12
13 // One-hot encoding for labels
14 if (transaction.isFraudulent()) {
15 labels.putScalar(new int[]{i, 1}, 1.0);
16 labels.putScalar(new int[]{i, 0}, 0.0);
17 } else {
18 labels.putScalar(new int[]{i, 0}, 1.0);
19 labels.putScalar(new int[]{i, 1}, 0.0);
20 }
21 }
22
23 return dataSet = new DataSet(features, labels);
24 }
Treinamos o modelo com um loop simples em várias épocas. Uma época é um termo em aprendizado de máquina que se refere a uma passagem completa do conjunto de dados completo pelo algoritmo de aprendizado.
1 public void trainModel(List<Transaction> transactions) {
2 // Shuffle the data to ensure random distribution
3 Collections.shuffle(transactions);
4
5 // Prepare training data
6 DataSet dataSet = prepareTrainingData(transactions);;
7
8 // Train the model
9 for (int epoch = 0; epoch < 100; epoch++) {
10 model.fit(dataSet);
11 }
12
13 System.out.println("Model trained successfully.");
14 }
Após o treinamento, adicionaremos um método evaluateModel para medir o desempenho dos modelos. Isso é crucial no treinamento do modelo de IA, a fim de refinar nossa metodologia de como queremos implementar a rede causal.
1public void evaluateModel(List<Transaction> transactions) {
2 // Split data into train and test sets
3 int trainSize = (int)(transactions.size() * 0.8);
4 List<Transaction> testSet = transactions.subList(trainSize, transactions.size());
5
6 // Prepare test data
7 DataSet testData = prepareTrainingData(testSet);
8
9 // Perform evaluation
10 Evaluation evaluation = new Evaluation(NUM_CLASSES);
11 INDArray predicted = model.output(testData.getFeatures());
12 evaluation.eval(testData.getLabels(), predicted);
13
14 // Print evaluation statistics
15 System.out.println(evaluation.stats());
16}
Veremos isso em nosso aplicação principal mais tarde e aprenderemos como interpretar nossos resultados.
Por último, adicionaremos um método para prever fraudes para uma transação singular.
1 public boolean predictFraud(Transaction transaction) {
2 // Convert transaction to INDArray
3 INDArray input = Nd4j.create(new double[][]{{transaction.getAmount()}});
4
5 // Perform prediction
6 INDArray output = model.output(input);
7
8 // Interpret the output
9 // Index 1 corresponds to fraud class (assuming one-hot encoding)
10 return output.getDouble(0, 1) > 0.5;
11 }
Podemos usar isso se quisermos gerar algumas transações sintéticas para testar nosso modelo depois de treinamento.

Modelo de detecção de fraudes em ação

Agora é hora de colocar todas as peças em ação. Criaremos FraudDetectionApp para manter todos os nossos componentes.
1package com.mongodb;
2
3import java.io.IOException;
4import java.util.Collections;
5import java.util.List;
6import java.util.Random;
7
8public class FraudDetectionApp {
9 private static FraudDetectionModel fraudDetectionModel;
10
11 public static void main(String[] args) {
12 MongoDBConnector mongoDBConnector = new MongoDBConnector();
13 TransactionRepository transactionRepository = new TransactionRepository(mongoDBConnector);
14 DataPreprocessor preprocessor = new DataPreprocessor(transactionRepository);
15
16 try {
17 // Load and prepare training data, and add to MongoDB
18 preprocessor.loadData("src/main/resources/creditcard.csv");
19 List<Transaction> transactions = transactionRepository.getAllTransactions();
20
21 // Shuffle and split data
22 Collections.shuffle(transactions);
23 int trainSize = (int) (transactions.size() * 0.8);
24 List<Transaction> trainSet = transactions.subList(0, trainSize);
25 List<Transaction> testSet = transactions.subList(trainSize, transactions.size());
26
27 System.out.println("Train size: " + trainSet.size() + ", Test size: " + testSet.size());
28
29 // Create and train fraud detection model
30 fraudDetectionModel = new FraudDetectionModel();
31 fraudDetectionModel.trainModel(trainSet);
32
33 // Evaluate model
34 fraudDetectionModel.evaluateModel(testSet);
35
36 } catch (IOException e) {
37 e.printStackTrace();
38 }
39 }
40}
Aqui, carregamos nossos dados em nosso banco de dados MongoDB . Em seguida, embaralhamos e divisão nossos dados em dados de treinamento e teste. Em seguida, formamos nosso modelo e testamos o modelo com os dados restantes.
Agora, vamos construir e executar nosso aplicação e dar uma olhada em nossos resultados.

Entendendo nossas métricas de avaliação

Vamos entender o que significa nosso resultado e o que podemos fazer com essas informações.
1========================Evaluation Metrics========================
2 # of classes: 2
3 Accuracy: 0.8441
4 Precision: 0.0011
5 Recall: 0.1053
6 F1 Score: 0.0022
7Precision, recall & F1: reported for positive class (class 1 - "1") only
  • Precisão: Proporção de todas as predições corretas.
1double calculateAccuracy(int TP, int TN, int FP, int FN) {
2 int totalPredictions = TP + TN + FP + FN;
3 double accuracy = (double)(TP + TN) / totalPredictions;
4 return accuracy;
5}
  • Precisão: proporção de previsões positiva e corretas.
1double calculatePrecision(int TP, int FP) {
2 double precision = (double) TP / (TP + FP);
3 return precision;
4}
  • Lembrete: proporção de positivos reais previstos corretamente.
1double calculateRecall(int TP, int FN) {
2 double recall = (double) TP / (TP + FN);
3 return recall;
4}
  • Pontuação F1 : média hermética de Precisão e Remoção.
1double calculateF1Score(double precision, double recall) {
2 double f1Score = 2 * (precision * recall) / (precision + recall);
3 return f1Score;
4}
Se você vir, a precisão é relativamente alta, mas por que a precisão é tão baixa? Bem, é porque podemos classificar a maioria dos resultados como não maliciosos e ainda estar corretos, mesmo que seja uma classificação mal informada.
A matriz de confusão resume o desempenho de um modelo de classificação mostrando as contagens de classificações reais versus previstas. Veja como interpretar a matriz de confusão fornecida:
1=========================Confusion Matrix=========================
2 0 1
3-----------
4 9615 1759 | 0 = 0
5 17 2 | 1 = 1
6
7Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
8==================================================================
  • As linhas representam as classes reais (verdade subjacente):
    • A primeira linha corresponde à classe real 0.
    • A segunda linha corresponde à classe real 1.
  • As colunas representam as classes previstas:
    • A primeira coluna corresponde à classe previstas 0.
    • A segunda coluna corresponde à classe previstas 1.

Valores

  • 9615: número de vezes que o modelo previu corretamente 0 quando a classe real era 0 (verdadeiros negativos, TL)
  • 1759: número de vezes que o modelo previu incorretamente1 quando a classe real era 0 (falsos positivos, FP)
  • 17: número de vezes que o modelo previu incorretamente 0 quando a classe real era 1 (falsos negativos, FN)
  • 2: número de vezes que o modelo previu corretamente 1 quando a classe real era 1 (verdadeiros positivos, TP)
Esta matriz de confusão sugere que o modelo é fortemente enviesado para a classe 0, já que a maioria das predições é para a classe 0. Isso é comum em conjuntos de dados com grave desequilíbrio de classe .
Bem, quando me reordenei para usar todos os recursos disponíveis, terminei com uma precisão muito melhor! Como compensação, ele não declara nenhuma transação como fraude. Não é muito útil. Então, como isso pode ser resolvido?
1========================Evaluation Metrics========================
2 # of classes: 2
3 Accuracy: 0.9986
4 Precision: 0.0000
5 Recall: 0.0000
6 F1 Score: 0.0000
7Precision, recall & F1: reported for positive class (class 1 - "1") only
8
9Warning: 1 class was never predicted by the model and was excluded from average precision
10Classes excluded from average precision: [1]
11
12=========================Confusion Matrix=========================
13 0 1
14-------------
15 11377 0 | 0 = 0
16 16 0 | 1 = 1
17
18Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
19==================================================================
Bem, realmente abordando a maneira como estamos lidando com nossos dados. Precisamos entender as diferenças entre nossas transações enganosas e as credíveis. Há uma série de pesquisas sobre esse tópico, e você pode aprender mais sobre alguns dos métodos em artigos como O que são dados desequilibrados e Como lidar com isso.

Conclusão

Neste tutorial, percorremos as etapas para criar um sistema de detecção de fraudes usando o deep learning4J e o MongoDB. A IA é um tópico popular e, usando o Kernel4j, você pode integrar a IA aos seus aplicativos Java .
Se quiser saber mais sobre o que você pode fazer com o MongoDB e o Java no mundo da IA, consulte Como implantar o Vector Search, o Atlas Search e os nós de pesquisa com o Atlas Kubernetes Operator.
Principais comentários nos fóruns
Ainda não há comentários sobre este artigo.
Iniciar a conversa

Ícone do FacebookÍcone do Twitterícone do linkedin
Avalie esse Tutorial
star-empty
star-empty
star-empty
star-empty
star-empty
Relacionado
Artigo

Orquestração do MongoDB com o Spring e Atlas Kubernetes Operator


Jun 12, 2024 | 13 min read
Tutorial

Introdução ao MongoDB e ao AWS Codewhisperer


Sep 26, 2024 | 3 min read
Tutorial

API de pesquisa de texto completo facetada Java usando o Atlas Search


Jan 17, 2025 | 18 min read
Tutorial

Construindo um Painel de Vendas Dinâmico e em Tempo Real no MongoDB


Mar 13, 2025 | 7 min read
Sumário