Skip to content

Commit 646c2b4

Browse files
author
Praful Makani
authored
feat: expose location field of model (#175)
1 parent 5212b2f commit 646c2b4

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ public Builder setLabels(Map<String, String> labels) {
108108
return this;
109109
}
110110

111+
@Override
112+
Builder setLocation(String location) {
113+
infoBuilder.setLocation(location);
114+
return this;
115+
}
116+
111117
@Override
112118
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
113119
infoBuilder.setTrainingRuns(trainingRunList);

google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public Model apply(ModelInfo ModelInfo) {
6868
private final Long lastModifiedTime;
6969
private final Long expirationTime;
7070
private final Labels labels;
71+
private final String location;
7172
private final ImmutableList<TrainingRun> trainingRunList;
7273
private final ImmutableList<StandardSQLField> featureColumnList;
7374
private final ImmutableList<StandardSQLField> labelColumnList;
@@ -97,6 +98,8 @@ public abstract static class Builder {
9798
*/
9899
public abstract Builder setLabels(Map<String, String> labels);
99100

101+
abstract Builder setLocation(String location);
102+
100103
public abstract Builder setModelId(ModelId modelId);
101104

102105
abstract Builder setEtag(String etag);
@@ -130,6 +133,7 @@ static class BuilderImpl extends Builder {
130133
private Long lastModifiedTime;
131134
private Long expirationTime;
132135
private Labels labels = Labels.ZERO;
136+
private String location;
133137
private List<TrainingRun> trainingRunList = Collections.emptyList();
134138
private List<StandardSQLField> labelColumnList = Collections.emptyList();
135139
private List<StandardSQLField> featureColumnList = Collections.emptyList();
@@ -150,6 +154,7 @@ static class BuilderImpl extends Builder {
150154
this.labelColumnList = modelInfo.labelColumnList;
151155
this.featureColumnList = modelInfo.featureColumnList;
152156
this.encryptionConfiguration = modelInfo.encryptionConfiguration;
157+
this.location = modelInfo.location;
153158
}
154159

155160
BuilderImpl(Model modelPb) {
@@ -165,6 +170,7 @@ static class BuilderImpl extends Builder {
165170
this.lastModifiedTime = modelPb.getLastModifiedTime();
166171
this.expirationTime = modelPb.getExpirationTime();
167172
this.labels = Labels.fromPb(modelPb.getLabels());
173+
this.location = modelPb.getLocation();
168174
if (modelPb.getTrainingRuns() != null) {
169175
this.trainingRunList = modelPb.getTrainingRuns();
170176
}
@@ -236,6 +242,12 @@ public Builder setLabels(Map<String, String> labels) {
236242
return this;
237243
}
238244

245+
@Override
246+
Builder setLocation(String location) {
247+
this.location = location;
248+
return this;
249+
}
250+
239251
@Override
240252
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
241253
this.trainingRunList = checkNotNull(trainingRunList);
@@ -276,6 +288,7 @@ public ModelInfo build() {
276288
this.lastModifiedTime = builder.lastModifiedTime;
277289
this.expirationTime = builder.expirationTime;
278290
this.labels = builder.labels;
291+
this.location = builder.location;
279292
this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList);
280293
this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList);
281294
this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList);
@@ -330,6 +343,11 @@ public Map<String, String> getLabels() {
330343
return labels.userMap();
331344
}
332345

346+
/** Returns a location of the model. */
347+
public String getLocation() {
348+
return location;
349+
}
350+
333351
/** Returns metadata about each training run iteration. */
334352
@BetaApi
335353
public ImmutableList<TrainingRun> getTrainingRuns() {
@@ -368,6 +386,7 @@ public String toString() {
368386
.add("lastModifiedTime", lastModifiedTime)
369387
.add("expirationTime", expirationTime)
370388
.add("labels", labels)
389+
.add("location", location)
371390
.add("trainingRuns", trainingRunList)
372391
.add("labelColumns", labelColumnList)
373392
.add("featureColumns", featureColumnList)
@@ -416,6 +435,7 @@ Model toPb() {
416435
modelPb.setLastModifiedTime(lastModifiedTime);
417436
modelPb.setExpirationTime(expirationTime);
418437
modelPb.setLabels(labels.toPb());
438+
modelPb.setLocation(location);
419439
modelPb.setTrainingRuns(trainingRunList);
420440
if (labelColumnList != null) {
421441
modelPb.setLabelColumns(Lists.transform(labelColumnList, StandardSQLField.TO_PB_FUNCTION));

google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public class ModelInfoTest {
3333
private static final Long EXPIRATION_TIME = 30L;
3434
private static final String DESCRIPTION = "description";
3535
private static final String FRIENDLY_NAME = "friendlyname";
36+
private static final String LOCATION = "US";
3637
private static final EncryptionConfiguration MODEL_ENCRYPTION_CONFIGURATION =
3738
EncryptionConfiguration.newBuilder().setKmsKeyName("KMS_KEY_1").build();
3839

@@ -52,6 +53,7 @@ public class ModelInfoTest {
5253
.setFriendlyName(FRIENDLY_NAME)
5354
.setTrainingRuns(TRAINING_RUN_LIST)
5455
.setEncryptionConfiguration(MODEL_ENCRYPTION_CONFIGURATION)
56+
.setLocation(LOCATION)
5557
.build();
5658

5759
@Test
@@ -75,6 +77,7 @@ public void testBuilder() {
7577
assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName());
7678
assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions());
7779
assertEquals(MODEL_ENCRYPTION_CONFIGURATION, MODEL_INFO.getEncryptionConfiguration());
80+
assertEquals(LOCATION, MODEL_INFO.getLocation());
7881
}
7982

8083
@Test
@@ -88,6 +91,7 @@ public void testOf() {
8891
assertNull(modelInfo.getDescription());
8992
assertNull(modelInfo.getFriendlyName());
9093
assertNull(modelInfo.getEncryptionConfiguration());
94+
assertNull(modelInfo.getLocation());
9195
assertEquals(modelInfo.getTrainingRuns().isEmpty(), true);
9296
assertEquals(modelInfo.getLabelColumns().isEmpty(), true);
9397
assertEquals(modelInfo.getFeatureColumns().isEmpty(), true);
@@ -113,6 +117,7 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) {
113117
assertEquals(expected.getDescription(), value.getDescription());
114118
assertEquals(expected.getFriendlyName(), value.getFriendlyName());
115119
assertEquals(expected.getLabels(), value.getLabels());
120+
assertEquals(expected.getLocation(), value.getLocation());
116121
assertEquals(expected.hashCode(), value.hashCode());
117122
assertEquals(expected.getTrainingRuns(), value.getTrainingRuns());
118123
assertEquals(expected.getLabelColumns(), value.getLabelColumns());

0 commit comments

Comments
 (0)