Commit 7bdfa55

fix: Simplify VertexAI with Suppliers.memoize and avoid accessing private members in tests. (#10694)

- Implement lazy init using Suppliers.memoize instead of an explicit lock.
- Add a builder() method in VertexAI.
- Update unit tests to avoid accessing private fields in VertexAI.

PiperOrigin-RevId: 624303836

Co-authored-by: A Vertex SDK engineer <[email protected]>
1 parent ae22f1c commit 7bdfa55
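
For context, Suppliers.memoize wraps a Supplier so that the delegate runs at most once, thread-safely, and every later get() returns the cached instance; that is what makes the explicit ReentrantLock and the Optional client fields unnecessary. Below is a minimal standalone sketch of the pattern under that assumption — ExpensiveClient and LazyHolder are illustrative names, not part of the SDK:

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

// Illustrative stand-in for PredictionServiceClient / LlmUtilityServiceClient.
final class ExpensiveClient {
  ExpensiveClient() {
    System.out.println("constructed exactly once");
  }
}

final class LazyHolder {
  // Suppliers.memoize caches the first result and is safe under concurrent get() calls.
  private final Supplier<ExpensiveClient> client = Suppliers.memoize(ExpensiveClient::new);

  ExpensiveClient client() {
    return client.get(); // created lazily on first call, reused afterwards
  }
}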

4 files changed: +103 -173 lines changed

java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java

Lines changed: 88 additions & 54 deletions

@@ -32,11 +32,13 @@
 import com.google.cloud.vertexai.api.PredictionServiceClient;
 import com.google.cloud.vertexai.api.PredictionServiceSettings;
 import com.google.common.base.Strings;
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableList;
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
 import java.io.IOException;
 import java.util.List;
 import java.util.Optional;
-import java.util.concurrent.locks.ReentrantLock;
 import java.util.logging.Level;
 import java.util.logging.Logger;

@@ -61,13 +63,12 @@ public class VertexAI implements AutoCloseable {
   private final String apiEndpoint;
   private final Transport transport;
   private final CredentialsProvider credentialsProvider;
-  private final ReentrantLock lock = new ReentrantLock();
-  // The clients will be instantiated lazily
-  private Optional<PredictionServiceClient> predictionServiceClient = Optional.empty();
-  private Optional<LlmUtilityServiceClient> llmUtilityClient = Optional.empty();
+
+  private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
+  private final transient Supplier<LlmUtilityServiceClient> llmClientSupplier;

   /**
-   * Construct a VertexAI instance.
+   * Constructs a VertexAI instance.
    *
    * @param projectId the default project to use when making API calls
    * @param location the default location to use when making API calls
@@ -78,8 +79,10 @@ public VertexAI(String projectId, String location) {
         location,
         Transport.GRPC,
         ImmutableList.of(),
-        Optional.empty(),
-        Optional.empty());
+        /* credentials= */ Optional.empty(),
+        /* apiEndpoint= */ Optional.empty(),
+        /* predictionClientSupplierOpt= */ Optional.empty(),
+        /* llmClientSupplierOpt= */ Optional.empty());
   }

   private VertexAI(
@@ -88,7 +91,9 @@ private VertexAI(
       Transport transport,
       List<String> scopes,
       Optional<Credentials> credentials,
-      Optional<String> apiEndpoint) {
+      Optional<String> apiEndpoint,
+      Optional<Supplier<PredictionServiceClient>> predictionClientSupplierOpt,
+      Optional<Supplier<LlmUtilityServiceClient>> llmClientSupplierOpt) {
     if (!scopes.isEmpty() && credentials.isPresent()) {
       throw new IllegalArgumentException(
           "At most one of Credentials and scopes should be specified.");
@@ -113,9 +118,19 @@ private VertexAI(
               .build();
     }

+    this.predictionClientSupplier =
+        Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient));
+
+    this.llmClientSupplier =
+        Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient));
+
     this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location));
   }

+  public static Builder builder() {
+    return new Builder();
+  }
+
   /** Builder for {@link VertexAI}. */
   public static class Builder {
     private String projectId;
@@ -125,11 +140,25 @@ public static class Builder {
     private Optional<Credentials> credentials = Optional.empty();
     private Optional<String> apiEndpoint = Optional.empty();

+    private Supplier<PredictionServiceClient> predictionClientSupplier;
+
+    private Supplier<LlmUtilityServiceClient> llmClientSupplier;
+
+    Builder() {}
+
     public VertexAI build() {
       checkNotNull(projectId, "projectId must be set.");
       checkNotNull(location, "location must be set.");

-      return new VertexAI(projectId, location, transport, scopes, credentials, apiEndpoint);
+      return new VertexAI(
+          projectId,
+          location,
+          transport,
+          scopes,
+          credentials,
+          apiEndpoint,
+          Optional.ofNullable(predictionClientSupplier),
+          Optional.ofNullable(llmClientSupplier));
     }

     public Builder setProjectId(String projectId) {
@@ -167,6 +196,19 @@ public Builder setCredentials(Credentials credentials) {
       return this;
     }

+    @CanIgnoreReturnValue
+    public Builder setPredictionClientSupplier(
+        Supplier<PredictionServiceClient> predictionClientSupplier) {
+      this.predictionClientSupplier = predictionClientSupplier;
+      return this;
+    }
+
+    @CanIgnoreReturnValue
+    public Builder setLlmClientSupplier(Supplier<LlmUtilityServiceClient> llmClientSupplier) {
+      this.llmClientSupplier = llmClientSupplier;
+      return this;
+    }
+
     public Builder setScopes(List<String> scopes) {
       checkNotNull(scopes, "scopes can't be null");

@@ -228,25 +270,23 @@ public Credentials getCredentials() throws IOException {
    * method calls that map to the API methods.
    */
   @InternalApi
-  public PredictionServiceClient getPredictionServiceClient() throws IOException {
-    if (predictionServiceClient.isPresent()) {
-      return predictionServiceClient.get();
-    }
-    lock.lock();
+  public PredictionServiceClient getPredictionServiceClient() {
+    return predictionClientSupplier.get();
+  }
+
+  private PredictionServiceClient newPredictionServiceClient() {
+    // Disable the warning message logged in getApplicationDefault
+    Logger defaultCredentialsProviderLogger =
+        Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
+    Level previousLevel = defaultCredentialsProviderLogger.getLevel();
+    defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
+
     try {
-      if (!predictionServiceClient.isPresent()) {
-        PredictionServiceSettings settings = getPredictionServiceSettings();
-        // Disable the warning message logged in getApplicationDefault
-        Logger defaultCredentialsProviderLogger =
-            Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
-        Level previousLevel = defaultCredentialsProviderLogger.getLevel();
-        defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
-        predictionServiceClient = Optional.of(PredictionServiceClient.create(settings));
-        defaultCredentialsProviderLogger.setLevel(previousLevel);
-      }
-      return predictionServiceClient.get();
+      return PredictionServiceClient.create(getPredictionServiceSettings());
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
     } finally {
-      lock.unlock();
+      defaultCredentialsProviderLogger.setLevel(previousLevel);
     }
   }

@@ -257,8 +297,8 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
     } else {
       builder = PredictionServiceSettings.newBuilder();
     }
-    builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
-    builder.setCredentialsProvider(this.credentialsProvider);
+    builder.setEndpoint(String.format("%s:443", apiEndpoint));
+    builder.setCredentialsProvider(credentialsProvider);

     HeaderProvider headerProvider =
         FixedHeaderProvider.create(
@@ -279,25 +319,23 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
    * calls that map to the API methods.
    */
   @InternalApi
-  public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
-    if (llmUtilityClient.isPresent()) {
-      return llmUtilityClient.get();
-    }
-    lock.lock();
+  public LlmUtilityServiceClient getLlmUtilityClient() {
+    return llmClientSupplier.get();
+  }
+
+  private LlmUtilityServiceClient newLlmUtilityClient() {
+    // Disable the warning message logged in getApplicationDefault
+    Logger defaultCredentialsProviderLogger =
+        Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
+    Level previousLevel = defaultCredentialsProviderLogger.getLevel();
+    defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
+
     try {
-      if (!llmUtilityClient.isPresent()) {
-        LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
-        // Disable the warning message logged in getApplicationDefault
-        Logger defaultCredentialsProviderLogger =
-            Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
-        Level previousLevel = defaultCredentialsProviderLogger.getLevel();
-        defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
-        llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings));
-        defaultCredentialsProviderLogger.setLevel(previousLevel);
-      }
-      return llmUtilityClient.get();
+      return LlmUtilityServiceClient.create(getLlmUtilityServiceClientSettings());
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
     } finally {
-      lock.unlock();
+      defaultCredentialsProviderLogger.setLevel(previousLevel);
     }
   }

@@ -308,8 +346,8 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
     } else {
       settingsBuilder = LlmUtilityServiceSettings.newBuilder();
     }
-    settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
-    settingsBuilder.setCredentialsProvider(this.credentialsProvider);
+    settingsBuilder.setEndpoint(String.format("%s:443", apiEndpoint));
+    settingsBuilder.setCredentialsProvider(credentialsProvider);

     HeaderProvider headerProvider =
         FixedHeaderProvider.create(
@@ -325,11 +363,7 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
   /** Closes the VertexAI instance together with all its instantiated clients. */
   @Override
   public void close() {
-    if (predictionServiceClient.isPresent()) {
-      predictionServiceClient.get().close();
-    }
-    if (llmUtilityClient.isPresent()) {
-      llmUtilityClient.get().close();
-    }
+    predictionClientSupplier.get().close();
+    llmClientSupplier.get().close();
   }
 }
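
As a usage sketch of the new surface (not taken from the commit: the project and location strings, the wrapper class, and the mock comment are placeholders), a caller gets lazy client creation by default and can inject a supplier, such as a mock, through the builder instead of reflection:

import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.PredictionServiceClient;

public class VertexAiBuilderSketch {
  public static void main(String[] args) {
    VertexAI vertexAi =
        VertexAI.builder()
            .setProjectId("my-project") // placeholder value
            .setLocation("us-central1") // placeholder value
            // In tests: .setPredictionClientSupplier(() -> mockPredictionServiceClient)
            .build();

    // First call resolves the memoized supplier and creates the client; later calls reuse it.
    PredictionServiceClient client = vertexAi.getPredictionServiceClient();

    // close() resolves both memoized suppliers and closes the underlying clients.
    vertexAi.close();
  }
}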

java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java

Lines changed: 8 additions & 7 deletions

@@ -41,11 +41,9 @@
 import com.google.cloud.vertexai.api.Tool;
 import com.google.cloud.vertexai.api.Type;
 import java.io.IOException;
-import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Optional;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -309,12 +307,15 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot
   public void testChatSessionMergeHistoryToRootChatSession() throws Exception {

     // (Arrange) Set up the return value of the generateContent
-    VertexAI vertexAi = new VertexAI(PROJECT, LOCATION);
-    GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi);

-    Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
-    field.setAccessible(true);
-    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
+    VertexAI vertexAi =
+        VertexAI.builder()
+            .setProjectId(PROJECT)
+            .setLocation(LOCATION)
+            .setPredictionClientSupplier(() -> mockPredictionServiceClient)
+            .build();
+
+    GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi);

     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
