Speeding Up Your AI-powered Search with JAI Async
Introduction
Businesses today must deal with massive amounts of constantly generated data. Data collected from various sources helps organizations better understand their customers' needs and preferences. With so much data available, search can become time-consuming and slow down business operations, so an efficient search system is essential to support business success. Our previous developer notebook discussed how combining ChatGPT with retrieval and re-ranking methods can improve search accuracy. You can obtain a fast and efficient search function by retrieving the most related content through cosine similarity to a hypothetical answer. By using async calls, you can decrease the search time even further, with speed-ups of roughly 30% to 200%. Steps 1 and 2 can run in parallel with steps 3 and 4 using the async interface for the OpenAI API provided by JAI. In this blog, we will dive into how you can speed up your search with JAI's async methods.
Speeding things up with JAI Async
In the previous developer notebook, we discussed how to improve search accuracy by combining ChatGPT with retrieval and re-ranking methods. This technique can be implemented on top of existing search systems, including Elasticsearch, Solr, or any custom search engine application.
To implement this approach, we followed these steps:
1. Generate a list of search queries from the user's question.
2. Run each query against a news search API to retrieve candidate articles.
3. Generate a hypothetical answer (HyDE) to the user's question.
4. Create an embedding of the hypothetical answer and rank the retrieved articles by cosine similarity to it.
The most related content, as measured by cosine similarity to the hypothetical answer (HyDE), is retrieved using this approach. It is fast and can be added to a search function you already have without managing a vector database.
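As a concrete illustration of the similarity step, here is a minimal sketch of the kind of dot helper the example code relies on (the name dot matches the dot(...) call in the main method; the vector values here are invented for illustration). Because OpenAI's text-embedding-ada-002 embeddings are normalized to unit length, a plain dot product gives the cosine similarity directly:

```java
public class CosineSketch {
    // Dot product of two embedding vectors. For unit-length embeddings
    // (ada-002 vectors are normalized), this equals the cosine similarity.
    static float dot(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] v1 = {1f, 0f};
        float[] v2 = {0.6f, 0.8f}; // unit length: 0.36 + 0.64 = 1
        System.out.println(dot(v1, v1)); // identical vectors: 1.0
        System.out.println(dot(v1, v2)); // similar vectors: 0.6
    }
}
```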
We can speed up this process by using async calls, making it roughly 30% to 200% faster.
Steps 1 and 2 can run in parallel with steps 3 and 4. We can easily do this because JAI provides an async interface for accessing the OpenAI API from Java.
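To see why this helps, here is a small self-contained sketch (not from the notebook; the method names and sleep times are invented stand-ins) of two independent CompletableFuture pipelines running concurrently. Each simulated call takes about 100 ms, but because both run at once, the total wall time is roughly one call, not two:

```java
import java.util.concurrent.CompletableFuture;

public class ParallelSketch {
    // Stand-ins for the two pipelines: each simulates a ~100 ms remote call.
    static String hypotheticalAnswer() { sleep(100); return "hypothetical answer"; }
    static int searchArticles()        { sleep(100); return 10; }

    static void sleep(long ms) {
        try { Thread.sleep(ms); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
    }

    public static void main(String[] args) {
        long start = System.currentTimeMillis();
        // Kick off both pipelines; neither blocks the other.
        CompletableFuture<String> answerFuture = CompletableFuture.supplyAsync(ParallelSketch::hypotheticalAnswer);
        CompletableFuture<Integer> articlesFuture = CompletableFuture.supplyAsync(ParallelSketch::searchArticles);
        // join() waits for both; elapsed time is typically ~100 ms, not ~200 ms.
        System.out.println(answerFuture.join() + ", articles: " + articlesFuture.join());
        System.out.println("elapsed ms: " + (System.currentTimeMillis() - start));
    }
}
```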
Digging in
Let’s look at an updated version of the main method in the WhoWonUFC290Async class. This version introduces asynchronous operations, using CompletableFuture to coordinate the tasks:
public static void main(String... args) throws Exception {
    try {
        long startTime = System.currentTimeMillis();
        final CountDownLatch countDownLatch = new CountDownLatch(2);

        // Generate a hypothetical answer and its embedding.
        final var hypotheticalAnswerEmbeddingFuture = hypotheticalAnswer()
                .thenCompose(WhoWonUFC290Async::embeddingsAsync)
                .thenApply(floats -> {
                    countDownLatch.countDown();
                    return floats;
                });

        // Generate a list of queries and use them to look up articles.
        final var queriesFuture = jsonGPT(QUERIES_INPUT.replace("{USER_QUESTION}", USER_QUESTION))
                .thenApply(queriesJson ->
                        JsonParserBuilder.builder().build().parse(queriesJson)
                                .getObjectNode().getArrayNode("queries")
                                .filter(node -> node instanceof StringNode)
                                .stream().map(Object::toString).collect(Collectors.toList()))
                .thenCompose(WhoWonUFC290Async::getArticles)
                .thenApply(objectNodes -> {
                    countDownLatch.countDown();
                    return objectNodes;
                });

        if (!countDownLatch.await(30, TimeUnit.SECONDS)) {
            throw new TimeoutException("Timed out waiting for hypotheticalAnswerEmbedding and articles");
        }

        final var articles = queriesFuture.get();
        final var hypotheticalAnswerEmbedding = hypotheticalAnswerEmbeddingFuture.get();

        // Extract article content and generate embeddings for each article.
        final var articleContent = articles.stream()
                .map(article -> String.format("%s %s %s", article.getString("title"),
                        article.getString("description"),
                        article.getString("content").substring(0, 100)))
                .collect(Collectors.toList());
        final var articleEmbeddings = embeddingsAsync(articleContent).get();

        // Calculate cosine similarities between the hypothetical answer embedding and article embeddings.
        final var cosineSimilarities = articleEmbeddings.stream()
                .map(articleEmbedding -> dot(hypotheticalAnswerEmbedding, articleEmbedding))
                .collect(Collectors.toList());

        // Create a set of scored articles based on cosine similarities.
        final var articleSet = IntStream.range(0, Math.min(cosineSimilarities.size(), articleContent.size()))
                .mapToObj(i -> new ScoredArticle(articles.get(i), cosineSimilarities.get(i)))
                .collect(Collectors.toSet());

        final var sortedArticles = new ArrayList<>(articleSet);
        sortedArticles.sort((o1, o2) -> Float.compare(o2.getScore(), o1.getScore()));

        // Print the top 5 scored articles.
        sortedArticles.subList(0, 5).forEach(System.out::println);

        // Format the top results as JSON strings.
        final var formattedTopResults = String.join(",\n", sortedArticles.stream()
                .map(ScoredArticle::getContent)
                .map(article -> String.format(Json.niceJson("{'title':'%s', 'url':'%s'," +
                                " 'description':'%s', 'content':'%s'}\n"),
                        article.getString("title"), article.getString("url"),
                        article.getString("description"),
                        getArticleContent(article)))
                .collect(Collectors.toList()).subList(0, 10));
        System.out.println(formattedTopResults);

        // Generate the final answer with the formatted top results.
        final var finalAnswer = jsonGPT(ANSWER_INPUT.replace("{USER_QUESTION}", USER_QUESTION)
                        .replace("{formatted_top_results}", formattedTopResults),
                "Output format is markdown").get();
        System.out.println(finalAnswer);

        long endTime = System.currentTimeMillis();
        System.out.println(endTime - startTime);
    } catch (Exception ex) {
        ex.printStackTrace();
    }
}
Let’s break the code down, starting with the async method calls, and then cover the rest.
Recall the four steps above. Steps 1 and 2 can run in parallel with steps 3 and 4. Since JAI has an async interface for accessing the OpenAI API in Java, we can easily do this. First, we will run steps 3 and 4 using the async interface of JAI.
public static CompletableFuture<String> jsonGPT(String input) {
    return jsonGPT(input, "All output shall be JSON");
}

public static CompletableFuture<String> jsonGPT(String input, String system) {
    final var client = OpenAIClient.builder()
            .setApiKey(System.getenv("OPENAI_API_KEY")).build();
    final var chatRequest = ChatRequest.builder()
            .addMessage(Message.builder().role(Role.SYSTEM).content(system).build())
            .addMessage(Message.builder().role(Role.USER).content(input).build())
            .build();
    return client.chatAsync(chatRequest).thenApply(chat -> {
        if (chat.getResponse().isPresent()) {
            return chat.getResponse().get().getChoices().get(0).getMessage().getContent();
        } else {
            System.out.println(chat.getStatusCode().orElse(666) + " "
                    + chat.getStatusMessage().orElse(""));
            throw new IllegalStateException();
        }
    });
}

public static CompletableFuture<String> hypotheticalAnswer() {
    final var input = HA_INPUT.replace("{USER_QUESTION}", USER_QUESTION);
    return jsonGPT(input).thenApply(response ->
            JsonParserBuilder.builder().build().parse(response)
                    .getObjectNode().getString("hypotheticalAnswer"));
}

public static CompletableFuture<float[]> embeddingsAsync(String input) {
    System.out.println("INPUT " + input);
    return embeddingsAsync(List.of(input)).thenApply(embeddings -> embeddings.get(0));
}

public static CompletableFuture<List<float[]>> embeddingsAsync(List<String> input) {
    System.out.println("INPUT " + input);
    if (input == null || input.isEmpty()) {
        return CompletableFuture.completedFuture(Collections.singletonList(new float[0]));
    }
    final var client = OpenAIClient.builder()
            .setApiKey(System.getenv("OPENAI_API_KEY")).build();
    return client.embeddingAsync(EmbeddingRequest.builder()
            .model("text-embedding-ada-002").input(input).build()).thenApply(embedding -> {
        if (embedding.getResponse().isPresent()) {
            return embedding.getResponse().get().getData().stream()
                    .map(Embedding::getEmbedding).collect(Collectors.toList());
        } else {
            System.out.println(embedding.getStatusCode().orElse(666) + " "
                    + embedding.getStatusMessage().orElse(""));
            throw new IllegalStateException(embedding.getStatusCode().orElse(666) + " "
                    + embedding.getStatusMessage().orElse(""));
        }
    });
}
public static void main(String... args) {
    try {
        long startTime = System.currentTimeMillis();
        final CountDownLatch countDownLatch = new CountDownLatch(2);

        // Generate a hypothetical answer and its embedding.
        final var hypotheticalAnswerEmbeddingFuture = hypotheticalAnswer()
                .thenCompose(WhoWonUFC290Async::embeddingsAsync)
                .thenApply(floats -> {
                    countDownLatch.countDown();
                    return floats;
                });
        ...
The example adds supporting methods (jsonGPT, hypotheticalAnswer, and embeddingsAsync) that each return a CompletableFuture rather than blocking on the OpenAI API.
Next, let’s run steps 1 and 2, generating the list of queries and then loading the articles for each query, at the same time.
public static void main(String... args) {
    try {
        final CountDownLatch countDownLatch = new CountDownLatch(2);

        // Generating a hypothetical answer and its embedding
        ...

        // Generate a list of queries and use them to look up articles.
        final var queriesFuture = jsonGPT(QUERIES_INPUT.replace("{USER_QUESTION}", USER_QUESTION))
                .thenApply(queriesJson ->
                        JsonParserBuilder.builder().build().parse(queriesJson)
                                .getObjectNode().getArrayNode("queries")
                                .filter(node -> node instanceof StringNode)
                                .stream().map(Object::toString).collect(Collectors.toList()))
                .thenCompose(WhoWonUFC290Async::getArticles)
                .thenApply(objectNodes -> {
                    countDownLatch.countDown();
                    return objectNodes;
                });

        if (!countDownLatch.await(30, TimeUnit.SECONDS)) {
            throw new TimeoutException("Timed out waiting for hypotheticalAnswerEmbedding and articles");
        }

        final var articles = queriesFuture.get();
        final var hypotheticalAnswerEmbedding = hypotheticalAnswerEmbeddingFuture.get();
        ...
Let’s take a look at the getArticles method (WhoWonUFC290Async::getArticles).
private static CompletableFuture<List<ObjectNode>> getArticles(List<String> queries) {
    final CompletableFuture<List<ObjectNode>> completableFuture = new CompletableFuture<>();
    final CountDownLatch countDownLatch = new CountDownLatch(queries.size());
    final LinkedTransferQueue<ObjectNode> results = new LinkedTransferQueue<>();
    final List<CompletableFuture<ArrayNode>> queryFutures = queries.stream()
            .map(WhoWonUFC290Async::searchNews).collect(Collectors.toList());
    final ExecutorService executorService = Executors.newCachedThreadPool();

    // Collect the results of each search as it completes.
    executorService.submit(() -> {
        for (CompletableFuture<ArrayNode> future : queryFutures) {
            try {
                ArrayNode arrayNode = future.get();
                arrayNode.forEach(node -> results.add((ObjectNode) node));
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                countDownLatch.countDown();
            }
        }
    });

    // Complete the returned future once every query has finished, or time out.
    executorService.submit(() -> {
        try {
            if (!countDownLatch.await(30, TimeUnit.SECONDS)) {
                throw new TimeoutException("Timed out waiting for articles");
            }
            completableFuture.complete(new ArrayList<>(results));
        } catch (Exception e) {
            completableFuture.completeExceptionally(e);
        } finally {
            executorService.shutdown();
        }
    });
    return completableFuture;
}
The getArticles method is a utility that retrieves articles for a list of queries asynchronously. It fires off one searchNews call per query, gathers the results on a background thread, and completes the returned CompletableFuture with a List of ObjectNode once every query has finished or the 30-second timeout expires.
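As a design note, the same fan-out can also be written without the CountDownLatch and extra executor by using CompletableFuture.allOf. The sketch below is a simplified, self-contained variant (searchNews here is a dummy stand-in returning fake titles, not the real news client), but the composition pattern carries over directly:

```java
import java.util.*;
import java.util.concurrent.*;
import java.util.stream.*;

public class AllOfSketch {
    // Stand-in for searchNews: maps a query to a list of article titles (hypothetical data).
    static CompletableFuture<List<String>> searchNews(String query) {
        return CompletableFuture.supplyAsync(() -> List.of(query + "-a", query + "-b"));
    }

    static CompletableFuture<List<String>> getArticles(List<String> queries) {
        List<CompletableFuture<List<String>>> futures = queries.stream()
                .map(AllOfSketch::searchNews)
                .collect(Collectors.toList());
        // allOf completes once every per-query future completes, so the
        // join() calls below are then non-blocking.
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
                .thenApply(v -> futures.stream()
                        .flatMap(f -> f.join().stream())
                        .collect(Collectors.toList()))
                .orTimeout(30, TimeUnit.SECONDS); // mirror the 30-second guard
    }

    public static void main(String[] args) {
        System.out.println(getArticles(List.of("ufc", "mma")).join());
        // prints [ufc-a, ufc-b, mma-a, mma-b]
    }
}
```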
The example continues with the remaining steps: extracting article content, calculating cosine similarities, creating scored articles, sorting them, and using the selected articles as part of the final answer. The big difference is that looking up articles based on generated queries happens at the same time as generating an ideal answer and creating that answer's embeddings.
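The scoring and sorting step is simple enough to show in isolation. Below is a sketch with a stand-in ScoredArticle record (the real class in the example wraps an ObjectNode, and the scores here are invented cosine similarities):

```java
import java.util.*;

public class ScoringSketch {
    // Hypothetical stand-in for the ScoredArticle class used in the example.
    record ScoredArticle(String content, float score) {}

    public static void main(String[] args) {
        List<ScoredArticle> articles = new ArrayList<>(List.of(
                new ScoredArticle("recap", 0.91f),
                new ScoredArticle("preview", 0.78f),
                new ScoredArticle("odds", 0.83f)));
        // Highest cosine similarity first, as in the notebook.
        articles.sort((a, b) -> Float.compare(b.score(), a.score()));
        System.out.println(articles.get(0).content()); // prints "recap"
    }
}
```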
Conclusion
The Java Open AI API Client Async methods can speed up generating a list of potential queries, executing search queries, creating an ideal answer, and creating an ideal answer embedding to measure and prioritize articles. The example code demonstrates the use of CompletableFuture to handle asynchronous operations with JAI, including generating a hypothetical answer and its embedding, obtaining a list of queries, and retrieving articles for each query. The examples, as before, also show how to calculate cosine similarities, create scored articles, and sort them to select the best articles for the final answer.