Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/langchain_google_cloud_sql_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def from_instance(
user: Optional[str] = None,
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
quota_project: Optional[str] = None,
) -> PostgresEngine:
# Running a loop in a background thread allows us to support
# async methods from non-async environments
Expand All @@ -128,6 +129,7 @@ def from_instance(
password,
loop=loop,
thread=thread,
quota_project=quota_project,
)
return asyncio.run_coroutine_threadsafe(coro, loop).result()

Expand All @@ -143,6 +145,7 @@ async def _create(
password: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
thread: Optional[Thread] = None,
quota_project: Optional[str] = None,
) -> PostgresEngine:
if bool(user) ^ bool(password):
raise ValueError(
Expand All @@ -152,7 +155,9 @@ async def _create(
)
if cls._connector is None:
cls._connector = Connector(
loop=asyncio.get_event_loop(), user_agent=USER_AGENT
loop=asyncio.get_event_loop(),
user_agent=USER_AGENT,
quota_project=quota_project,
)

# if user and password are given, use basic auth
Expand Down Expand Up @@ -197,6 +202,7 @@ async def afrom_instance(
user: Optional[str] = None,
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
quota_project: Optional[str] = None,
) -> PostgresEngine:
return await cls._create(
project_id,
Expand All @@ -206,6 +212,7 @@ async def afrom_instance(
ip_type,
user,
password,
quota_project=quota_project,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/test_postgresql_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ async def test_password(
database=db_name,
user=user,
password=password,
quota_project=db_project,
)
assert engine
engine._execute("SELECT 1")
Expand Down
3 changes: 3 additions & 0 deletions tests/test_postgresql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
class TestLoaderAsync:
@pytest_asyncio.fixture
async def engine(self):
PostgresEngine._connector = None
engine = await PostgresEngine.afrom_instance(
project_id=project_id,
instance=instance_id,
Expand All @@ -48,6 +49,7 @@ async def engine(self):

@pytest_asyncio.fixture
def sync_engine(self):
PostgresEngine._connector = None
engine = PostgresEngine.from_instance(
project_id=project_id,
instance=instance_id,
Expand Down Expand Up @@ -734,6 +736,7 @@ async def test_delete_doc_with_customized_metadata(
assert len(await self._collect_async_items(loader.alazy_load())) == 0

def test_sync_engine(self):
PostgresEngine._connector = None
engine = PostgresEngine.from_instance(
project_id=project_id,
instance=instance_id,
Expand Down