32
32
import com .google .cloud .vertexai .api .PredictionServiceClient ;
33
33
import com .google .cloud .vertexai .api .PredictionServiceSettings ;
34
34
import com .google .common .base .Strings ;
35
+ import com .google .common .base .Supplier ;
36
+ import com .google .common .base .Suppliers ;
35
37
import com .google .common .collect .ImmutableList ;
38
+ import com .google .errorprone .annotations .CanIgnoreReturnValue ;
36
39
import java .io .IOException ;
37
40
import java .util .List ;
38
41
import java .util .Optional ;
39
- import java .util .concurrent .locks .ReentrantLock ;
40
42
import java .util .logging .Level ;
41
43
import java .util .logging .Logger ;
42
44
@@ -61,13 +63,12 @@ public class VertexAI implements AutoCloseable {
61
63
private final String apiEndpoint ;
62
64
private final Transport transport ;
63
65
private final CredentialsProvider credentialsProvider ;
64
- private final ReentrantLock lock = new ReentrantLock ();
65
- // The clients will be instantiated lazily
66
- private Optional <PredictionServiceClient > predictionServiceClient = Optional .empty ();
67
- private Optional <LlmUtilityServiceClient > llmUtilityClient = Optional .empty ();
66
+
67
+ private final transient Supplier <PredictionServiceClient > predictionClientSupplier ;
68
+ private final transient Supplier <LlmUtilityServiceClient > llmClientSupplier ;
68
69
69
70
/**
70
- * Construct a VertexAI instance.
71
+ * Constructs a VertexAI instance.
71
72
*
72
73
* @param projectId the default project to use when making API calls
73
74
* @param location the default location to use when making API calls
@@ -78,8 +79,10 @@ public VertexAI(String projectId, String location) {
78
79
location ,
79
80
Transport .GRPC ,
80
81
ImmutableList .of (),
81
- Optional .empty (),
82
- Optional .empty ());
82
+ /* credentials= */ Optional .empty (),
83
+ /* apiEndpoint= */ Optional .empty (),
84
+ /* predictionClientSupplierOpt= */ Optional .empty (),
85
+ /* llmClientSupplierOpt= */ Optional .empty ());
83
86
}
84
87
85
88
private VertexAI (
@@ -88,7 +91,9 @@ private VertexAI(
88
91
Transport transport ,
89
92
List <String > scopes ,
90
93
Optional <Credentials > credentials ,
91
- Optional <String > apiEndpoint ) {
94
+ Optional <String > apiEndpoint ,
95
+ Optional <Supplier <PredictionServiceClient >> predictionClientSupplierOpt ,
96
+ Optional <Supplier <LlmUtilityServiceClient >> llmClientSupplierOpt ) {
92
97
if (!scopes .isEmpty () && credentials .isPresent ()) {
93
98
throw new IllegalArgumentException (
94
99
"At most one of Credentials and scopes should be specified." );
@@ -113,9 +118,19 @@ private VertexAI(
113
118
.build ();
114
119
}
115
120
121
+ this .predictionClientSupplier =
122
+ Suppliers .memoize (predictionClientSupplierOpt .orElse (this ::newPredictionServiceClient ));
123
+
124
+ this .llmClientSupplier =
125
+ Suppliers .memoize (llmClientSupplierOpt .orElse (this ::newLlmUtilityClient ));
126
+
116
127
this .apiEndpoint = apiEndpoint .orElse (String .format ("%s-aiplatform.googleapis.com" , location ));
117
128
}
118
129
130
+ public static Builder builder () {
131
+ return new Builder ();
132
+ }
133
+
119
134
/** Builder for {@link VertexAI}. */
120
135
public static class Builder {
121
136
private String projectId ;
@@ -125,11 +140,25 @@ public static class Builder {
125
140
private Optional <Credentials > credentials = Optional .empty ();
126
141
private Optional <String > apiEndpoint = Optional .empty ();
127
142
143
+ private Supplier <PredictionServiceClient > predictionClientSupplier ;
144
+
145
+ private Supplier <LlmUtilityServiceClient > llmClientSupplier ;
146
+
147
+ Builder () {}
148
+
128
149
public VertexAI build () {
129
150
checkNotNull (projectId , "projectId must be set." );
130
151
checkNotNull (location , "location must be set." );
131
152
132
- return new VertexAI (projectId , location , transport , scopes , credentials , apiEndpoint );
153
+ return new VertexAI (
154
+ projectId ,
155
+ location ,
156
+ transport ,
157
+ scopes ,
158
+ credentials ,
159
+ apiEndpoint ,
160
+ Optional .ofNullable (predictionClientSupplier ),
161
+ Optional .ofNullable (llmClientSupplier ));
133
162
}
134
163
135
164
public Builder setProjectId (String projectId ) {
@@ -167,6 +196,19 @@ public Builder setCredentials(Credentials credentials) {
167
196
return this ;
168
197
}
169
198
199
+ @ CanIgnoreReturnValue
200
+ public Builder setPredictionClientSupplier (
201
+ Supplier <PredictionServiceClient > predictionClientSupplier ) {
202
+ this .predictionClientSupplier = predictionClientSupplier ;
203
+ return this ;
204
+ }
205
+
206
+ @ CanIgnoreReturnValue
207
+ public Builder setLlmClientSupplier (Supplier <LlmUtilityServiceClient > llmClientSupplier ) {
208
+ this .llmClientSupplier = llmClientSupplier ;
209
+ return this ;
210
+ }
211
+
170
212
public Builder setScopes (List <String > scopes ) {
171
213
checkNotNull (scopes , "scopes can't be null" );
172
214
@@ -228,25 +270,23 @@ public Credentials getCredentials() throws IOException {
228
270
* method calls that map to the API methods.
229
271
*/
230
272
@ InternalApi
231
- public PredictionServiceClient getPredictionServiceClient () throws IOException {
232
- if (predictionServiceClient .isPresent ()) {
233
- return predictionServiceClient .get ();
234
- }
235
- lock .lock ();
273
+ public PredictionServiceClient getPredictionServiceClient () {
274
+ return predictionClientSupplier .get ();
275
+ }
276
+
277
+ private PredictionServiceClient newPredictionServiceClient () {
278
+ // Disable the warning message logged in getApplicationDefault
279
+ Logger defaultCredentialsProviderLogger =
280
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
281
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
282
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
283
+
236
284
try {
237
- if (!predictionServiceClient .isPresent ()) {
238
- PredictionServiceSettings settings = getPredictionServiceSettings ();
239
- // Disable the warning message logged in getApplicationDefault
240
- Logger defaultCredentialsProviderLogger =
241
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
242
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
243
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
244
- predictionServiceClient = Optional .of (PredictionServiceClient .create (settings ));
245
- defaultCredentialsProviderLogger .setLevel (previousLevel );
246
- }
247
- return predictionServiceClient .get ();
285
+ return PredictionServiceClient .create (getPredictionServiceSettings ());
286
+ } catch (IOException e ) {
287
+ throw new IllegalStateException (e );
248
288
} finally {
249
- lock . unlock ( );
289
+ defaultCredentialsProviderLogger . setLevel ( previousLevel );
250
290
}
251
291
}
252
292
@@ -257,8 +297,8 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
257
297
} else {
258
298
builder = PredictionServiceSettings .newBuilder ();
259
299
}
260
- builder .setEndpoint (String .format ("%s:443" , this . apiEndpoint ));
261
- builder .setCredentialsProvider (this . credentialsProvider );
300
+ builder .setEndpoint (String .format ("%s:443" , apiEndpoint ));
301
+ builder .setCredentialsProvider (credentialsProvider );
262
302
263
303
HeaderProvider headerProvider =
264
304
FixedHeaderProvider .create (
@@ -279,25 +319,23 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
279
319
* calls that map to the API methods.
280
320
*/
281
321
@ InternalApi
282
- public LlmUtilityServiceClient getLlmUtilityClient () throws IOException {
283
- if (llmUtilityClient .isPresent ()) {
284
- return llmUtilityClient .get ();
285
- }
286
- lock .lock ();
322
+ public LlmUtilityServiceClient getLlmUtilityClient () {
323
+ return llmClientSupplier .get ();
324
+ }
325
+
326
+ private LlmUtilityServiceClient newLlmUtilityClient () {
327
+ // Disable the warning message logged in getApplicationDefault
328
+ Logger defaultCredentialsProviderLogger =
329
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
330
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
331
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
332
+
287
333
try {
288
- if (!llmUtilityClient .isPresent ()) {
289
- LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings ();
290
- // Disable the warning message logged in getApplicationDefault
291
- Logger defaultCredentialsProviderLogger =
292
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
293
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
294
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
295
- llmUtilityClient = Optional .of (LlmUtilityServiceClient .create (settings ));
296
- defaultCredentialsProviderLogger .setLevel (previousLevel );
297
- }
298
- return llmUtilityClient .get ();
334
+ return LlmUtilityServiceClient .create (getLlmUtilityServiceClientSettings ());
335
+ } catch (IOException e ) {
336
+ throw new IllegalStateException (e );
299
337
} finally {
300
- lock . unlock ( );
338
+ defaultCredentialsProviderLogger . setLevel ( previousLevel );
301
339
}
302
340
}
303
341
@@ -308,8 +346,8 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
308
346
} else {
309
347
settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
310
348
}
311
- settingsBuilder .setEndpoint (String .format ("%s:443" , this . apiEndpoint ));
312
- settingsBuilder .setCredentialsProvider (this . credentialsProvider );
349
+ settingsBuilder .setEndpoint (String .format ("%s:443" , apiEndpoint ));
350
+ settingsBuilder .setCredentialsProvider (credentialsProvider );
313
351
314
352
HeaderProvider headerProvider =
315
353
FixedHeaderProvider .create (
@@ -325,11 +363,7 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
325
363
/** Closes the VertexAI instance together with all its instantiated clients. */
326
364
@ Override
327
365
public void close () {
328
- if (predictionServiceClient .isPresent ()) {
329
- predictionServiceClient .get ().close ();
330
- }
331
- if (llmUtilityClient .isPresent ()) {
332
- llmUtilityClient .get ().close ();
333
- }
366
+ predictionClientSupplier .get ().close ();
367
+ llmClientSupplier .get ().close ();
334
368
}
335
369
}
0 commit comments